src/google/adk/evaluation/eval_metrics.py: 7 additions & 0 deletions

@@ -226,6 +226,13 @@ class MatchType(Enum):
),
)

ignore_args: bool = Field(
default=False,
description=(
"If True, only tool names are compared; arguments are ignored."
),
)

@field_validator("match_type", mode="before")
@classmethod
def _coerce_match_type(cls, value: object) -> object:
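Enabling the new flag from user code could look like the following sketch; the import path and the metric-name string are inferred from this diff and the tests below, so treat them as assumptions:

from google.adk.evaluation.eval_metrics import (
    EvalMetric,
    ToolTrajectoryCriterion,
)

# Name-only trajectory matching: tool names must line up, arguments may
# differ. ignore_args defaults to False, so existing configs keep the old
# strict name-and-args comparison.
metric = EvalMetric(
    metric_name="tool_trajectory_avg_score",  # PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value
    criterion=ToolTrajectoryCriterion(
        threshold=0.5,
        match_type=ToolTrajectoryCriterion.MatchType.IN_ORDER,
        ignore_args=True,
    ),
)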
src/google/adk/evaluation/trajectory_evaluator.py: 11 additions & 9 deletions

@@ -82,18 +82,17 @@ def __init__(
)
self._threshold = criterion.threshold
self._match_type = criterion.match_type
+      self._ignore_args = criterion.ignore_args
except ValidationError as e:
expected_criterion_type_error = ValueError(
f"`{eval_metric.metric_name}` metric expects a criterion of type"
f" `{TrajectoryEvaluator.criterion_type}`."
)
raise expected_criterion_type_error from e
-    elif eval_metric:
-      self._threshold = eval_metric.threshold
-      self._match_type = ToolTrajectoryCriterion.MatchType.EXACT
     else:
-      self._threshold = threshold
+      self._threshold = eval_metric.threshold if eval_metric else threshold
       self._match_type = ToolTrajectoryCriterion.MatchType.EXACT
+      self._ignore_args = False

@override
def evaluate_invocations(
@@ -191,9 +190,8 @@ def _are_tool_calls_in_order_match(
try:
current_expected = next(expected_it)
for actual in actual_tool_calls:
-        if (
-            actual.name == current_expected.name
-            and actual.args == current_expected.args
+        if actual.name == current_expected.name and (
+            self._ignore_args or actual.args == current_expected.args
         ):
current_expected = next(expected_it)
except StopIteration:
@@ -229,7 +227,9 @@ def _are_tool_calls_any_order_match(
for expected in expected_tool_calls:
found = False
for i, actual in enumerate(actual_tool_calls_copy):
-        if actual.name == expected.name and actual.args == expected.args:
+        if actual.name == expected.name and (
+            self._ignore_args or actual.args == expected.args
+        ):
actual_tool_calls_copy.pop(i)
found = True
break
@@ -260,7 +260,9 @@ def _are_tool_calls_exact_match(
return False

for actual, expected in zip(actual_tool_calls, expected_tool_calls):
-      if actual.name != expected.name or actual.args != expected.args:
+      if actual.name != expected.name or (
+          not self._ignore_args and actual.args != expected.args
+      ):
return False

return True
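With ignore_args threaded through all three helpers, the matching rules reduce to one shared predicate. Here is a self-contained sketch of the semantics, as an illustrative restatement rather than the module's actual code; tool calls are assumed to expose .name and .args like genai_types.FunctionCall:

def _calls_match(actual, expected, ignore_args: bool) -> bool:
  # Names must always match; args are compared only when ignore_args is False.
  return actual.name == expected.name and (
      ignore_args or actual.args == expected.args
  )

def in_order_match(actual_calls, expected_calls, ignore_args=False) -> bool:
  # Expected calls must appear as a subsequence of the actual calls;
  # extra actual calls before, between, or after are tolerated.
  it = iter(expected_calls)
  pending = next(it, None)
  for actual in actual_calls:
    if pending is not None and _calls_match(actual, pending, ignore_args):
      pending = next(it, None)
  return pending is None

def any_order_match(actual_calls, expected_calls, ignore_args=False) -> bool:
  # Every expected call must consume a distinct actual call, in any order.
  pool = list(actual_calls)
  for expected in expected_calls:
    for i, actual in enumerate(pool):
      if _calls_match(actual, expected, ignore_args):
        pool.pop(i)
        break
    else:
      return False
  return True

def exact_match(actual_calls, expected_calls, ignore_args=False) -> bool:
  # Same call count, and the calls match pairwise in position.
  return len(actual_calls) == len(expected_calls) and all(
      _calls_match(a, e, ignore_args)
      for a, e in zip(actual_calls, expected_calls)
  )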
tests/unittests/evaluation/test_trajectory_evaluator.py: 240 additions & 0 deletions

@@ -462,3 +462,243 @@ def test_evaluate_invocations_no_invocations(evaluator: TrajectoryEvaluator):
assert result.overall_score is None
assert result.overall_eval_status == EvalStatus.NOT_EVALUATED
assert not result.per_invocation_results


# --- ignore_args tests ---


def _make_ignore_args_evaluator(
match_type: ToolTrajectoryCriterion.MatchType,
) -> TrajectoryEvaluator:
return TrajectoryEvaluator(
eval_metric=EvalMetric(
metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value,
criterion=ToolTrajectoryCriterion(
threshold=0.5,
match_type=match_type,
ignore_args=True,
),
)
)


_EXACT = ToolTrajectoryCriterion.MatchType.EXACT
_IN_ORDER = ToolTrajectoryCriterion.MatchType.IN_ORDER
_ANY_ORDER = ToolTrajectoryCriterion.MatchType.ANY_ORDER


@pytest.mark.parametrize(
("match_type", "actual_tools", "expected_tools", "expected_score"),
[
# EXACT: different args, same names -> pass
(
_EXACT,
[
genai_types.FunctionCall(name="t1", args={"a": 1}),
genai_types.FunctionCall(name="t2", args={"b": 2}),
],
[
genai_types.FunctionCall(name="t1", args={"x": 99}),
genai_types.FunctionCall(name="t2", args={"y": 100}),
],
1.0,
),
# EXACT: different names -> fail
(
_EXACT,
[genai_types.FunctionCall(name="t1", args={})],
[genai_types.FunctionCall(name="t2", args={})],
0.0,
),
# EXACT: different tool count -> fail
(
_EXACT,
[
genai_types.FunctionCall(name="t1", args={}),
genai_types.FunctionCall(name="t2", args={}),
],
[genai_types.FunctionCall(name="t1", args={})],
0.0,
),
# EXACT: empty lists -> pass
(
_EXACT,
[],
[],
1.0,
),
# IN_ORDER: different args with extra tools -> pass
(
_IN_ORDER,
[
genai_types.FunctionCall(name="t1", args={"a": 1}),
genai_types.FunctionCall(name="extra", args={}),
genai_types.FunctionCall(name="t2", args={"b": 2}),
],
[
genai_types.FunctionCall(name="t1", args={"x": 99}),
genai_types.FunctionCall(name="t2", args={"y": 100}),
],
1.0,
),
# IN_ORDER: wrong order -> fail
(
_IN_ORDER,
[
genai_types.FunctionCall(name="t2", args={}),
genai_types.FunctionCall(name="t1", args={}),
],
[
genai_types.FunctionCall(name="t1", args={}),
genai_types.FunctionCall(name="t2", args={}),
],
0.0,
),
# IN_ORDER: missing tool -> fail
(
_IN_ORDER,
[genai_types.FunctionCall(name="t1", args={})],
[
genai_types.FunctionCall(name="t1", args={}),
genai_types.FunctionCall(name="t2", args={}),
],
0.0,
),
# ANY_ORDER: different args, swapped order -> pass
(
_ANY_ORDER,
[
genai_types.FunctionCall(name="t2", args={"b": 2}),
genai_types.FunctionCall(name="t1", args={"a": 1}),
],
[
genai_types.FunctionCall(name="t1", args={"x": 99}),
genai_types.FunctionCall(name="t2", args={"y": 100}),
],
1.0,
),
# ANY_ORDER: missing tool -> fail
(
_ANY_ORDER,
[genai_types.FunctionCall(name="t1", args={})],
[
genai_types.FunctionCall(name="t1", args={}),
genai_types.FunctionCall(name="t2", args={}),
],
0.0,
),
],
ids=[
"exact_different_args_pass",
"exact_different_names_fail",
"exact_different_count_fail",
"exact_empty_lists_pass",
"in_order_different_args_pass",
"in_order_wrong_order_fail",
"in_order_missing_tool_fail",
"any_order_different_args_pass",
"any_order_missing_tool_fail",
],
)
def test_ignore_args(match_type, actual_tools, expected_tools, expected_score):
ev = _make_ignore_args_evaluator(match_type)
actual = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(tool_uses=actual_tools),
)
expected = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(tool_uses=expected_tools),
)
result = ev.evaluate_invocations([actual], [expected])
assert result.overall_score == expected_score


def test_ignore_args_false_still_checks_args():
"""Confirm ignore_args=False (default) still enforces arg matching."""
ev = TrajectoryEvaluator(
eval_metric=EvalMetric(
metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value,
criterion=ToolTrajectoryCriterion(
threshold=0.5,
match_type=ToolTrajectoryCriterion.MatchType.EXACT,
ignore_args=False,
),
)
)
actual = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t1", args={"a": 1})]
),
)
expected = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t1", args={"a": 2})]
),
)
result = ev.evaluate_invocations([actual], [expected])
assert result.overall_score == 0.0


def test_ignore_args_multiple_invocations_mixed():
"""ignore_args with multiple invocations: one matches, one doesn't."""
ev = _make_ignore_args_evaluator(ToolTrajectoryCriterion.MatchType.EXACT)
inv1_actual = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t1", args={"a": 1})]
),
)
inv1_expected = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t1", args={"z": 99})]
),
)
inv2_actual = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t1", args={})]
),
)
inv2_expected = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t2", args={})]
),
)
result = ev.evaluate_invocations(
[inv1_actual, inv2_actual], [inv1_expected, inv2_expected]
)
assert result.overall_score == 0.5
assert result.per_invocation_results[0].score == 1.0
assert result.per_invocation_results[1].score == 0.0


def test_ignore_args_with_camel_case_dict_config():
"""Tests ignore_args works via camelCase key (ignoreArgs) in dict."""
eval_metric = EvalMetric(
metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value,
criterion={
"threshold": 0.5,
"matchType": "EXACT",
"ignoreArgs": True,
},
)
ev = TrajectoryEvaluator(eval_metric=eval_metric)
actual = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t1", args={"a": 1})]
),
)
expected = Invocation(
user_content=_USER_CONTENT,
intermediate_data=IntermediateData(
tool_uses=[genai_types.FunctionCall(name="t1", args={"z": 999})]
),
)
result = ev.evaluate_invocations([actual], [expected])
assert result.overall_score == 1.0
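For completeness, driving the evaluator directly (outside pytest) might look like the sketch below; the import locations are inferred from the file paths in this diff and could differ:

from google.genai import types as genai_types

# Assumed import paths, based on the modules this PR touches.
from google.adk.evaluation.eval_case import IntermediateData, Invocation
from google.adk.evaluation.eval_metrics import EvalMetric
from google.adk.evaluation.trajectory_evaluator import TrajectoryEvaluator

user_content = genai_types.Content(
    role="user", parts=[genai_types.Part(text="find me a flight")]
)
# camelCase aliases work for dict criteria, as the last test above shows.
evaluator = TrajectoryEvaluator(
    eval_metric=EvalMetric(
        metric_name="tool_trajectory_avg_score",
        criterion={"threshold": 1.0, "matchType": "ANY_ORDER", "ignoreArgs": True},
    )
)
actual = Invocation(
    user_content=user_content,
    intermediate_data=IntermediateData(
        tool_uses=[genai_types.FunctionCall(name="search_flights", args={"dest": "SFO"})]
    ),
)
expected = Invocation(
    user_content=user_content,
    intermediate_data=IntermediateData(
        tool_uses=[genai_types.FunctionCall(name="search_flights", args={})]
    ),
)
result = evaluator.evaluate_invocations([actual], [expected])
print(result.overall_score)  # 1.0: names match, args are ignored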