diff --git a/src/google/adk/evaluation/eval_metrics.py b/src/google/adk/evaluation/eval_metrics.py
index eb7c7e36cb..8a2821b905 100644
--- a/src/google/adk/evaluation/eval_metrics.py
+++ b/src/google/adk/evaluation/eval_metrics.py
@@ -226,6 +226,13 @@ class MatchType(Enum):
       ),
   )
 
+  ignore_args: bool = Field(
+      default=False,
+      description=(
+          "If True, only tool names are compared; arguments are ignored."
+      ),
+  )
+
   @field_validator("match_type", mode="before")
   @classmethod
   def _coerce_match_type(cls, value: object) -> object:
diff --git a/src/google/adk/evaluation/trajectory_evaluator.py b/src/google/adk/evaluation/trajectory_evaluator.py
index 07626d7687..25720c4a52 100644
--- a/src/google/adk/evaluation/trajectory_evaluator.py
+++ b/src/google/adk/evaluation/trajectory_evaluator.py
@@ -82,18 +82,17 @@ def __init__(
         )
         self._threshold = criterion.threshold
         self._match_type = criterion.match_type
+        self._ignore_args = criterion.ignore_args
       except ValidationError as e:
         expected_criterion_type_error = ValueError(
             f"`{eval_metric.metric_name}` metric expects a criterion of type"
            f" `{TrajectoryEvaluator.criterion_type}`."
         )
         raise expected_criterion_type_error from e
-    elif eval_metric:
-      self._threshold = eval_metric.threshold
-      self._match_type = ToolTrajectoryCriterion.MatchType.EXACT
     else:
-      self._threshold = threshold
+      self._threshold = eval_metric.threshold if eval_metric else threshold
       self._match_type = ToolTrajectoryCriterion.MatchType.EXACT
+      self._ignore_args = False
 
   @override
   def evaluate_invocations(
@@ -191,9 +190,8 @@ def _are_tool_calls_in_order_match(
     try:
       current_expected = next(expected_it)
       for actual in actual_tool_calls:
-        if (
-            actual.name == current_expected.name
-            and actual.args == current_expected.args
+        if actual.name == current_expected.name and (
+            self._ignore_args or actual.args == current_expected.args
         ):
           current_expected = next(expected_it)
     except StopIteration:
@@ -229,7 +227,9 @@ def _are_tool_calls_any_order_match(
     for expected in expected_tool_calls:
       found = False
       for i, actual in enumerate(actual_tool_calls_copy):
-        if actual.name == expected.name and actual.args == expected.args:
+        if actual.name == expected.name and (
+            self._ignore_args or actual.args == expected.args
+        ):
           actual_tool_calls_copy.pop(i)
           found = True
           break
@@ -260,7 +260,9 @@ def _are_tool_calls_exact_match(
       return False
 
     for actual, expected in zip(actual_tool_calls, expected_tool_calls):
-      if actual.name != expected.name or actual.args != expected.args:
+      if actual.name != expected.name or (
+          not self._ignore_args and actual.args != expected.args
+      ):
         return False
 
     return True
diff --git a/tests/unittests/evaluation/test_trajectory_evaluator.py b/tests/unittests/evaluation/test_trajectory_evaluator.py
index 0fa3fa5a73..59e1b96b8c 100644
--- a/tests/unittests/evaluation/test_trajectory_evaluator.py
+++ b/tests/unittests/evaluation/test_trajectory_evaluator.py
@@ -462,3 +462,243 @@ def test_evaluate_invocations_no_invocations(evaluator: TrajectoryEvaluator):
   assert result.overall_score is None
   assert result.overall_eval_status == EvalStatus.NOT_EVALUATED
   assert not result.per_invocation_results
+
+
+# --- ignore_args tests ---
+
+
+def _make_ignore_args_evaluator(
+    match_type: ToolTrajectoryCriterion.MatchType,
+) -> TrajectoryEvaluator:
+  return TrajectoryEvaluator(
+      eval_metric=EvalMetric(
+          metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value,
+          criterion=ToolTrajectoryCriterion(
+              threshold=0.5,
+              match_type=match_type,
+              ignore_args=True,
+          ),
+      )
+  )
+
+
+_EXACT = ToolTrajectoryCriterion.MatchType.EXACT
+_IN_ORDER = ToolTrajectoryCriterion.MatchType.IN_ORDER
+_ANY_ORDER = ToolTrajectoryCriterion.MatchType.ANY_ORDER
+
+
+@pytest.mark.parametrize(
+    ("match_type", "actual_tools", "expected_tools", "expected_score"),
+    [
+        # EXACT: different args, same names -> pass
+        (
+            _EXACT,
+            [
+                genai_types.FunctionCall(name="t1", args={"a": 1}),
+                genai_types.FunctionCall(name="t2", args={"b": 2}),
+            ],
+            [
+                genai_types.FunctionCall(name="t1", args={"x": 99}),
+                genai_types.FunctionCall(name="t2", args={"y": 100}),
+            ],
+            1.0,
+        ),
+        # EXACT: different names -> fail
+        (
+            _EXACT,
+            [genai_types.FunctionCall(name="t1", args={})],
+            [genai_types.FunctionCall(name="t2", args={})],
+            0.0,
+        ),
+        # EXACT: different tool count -> fail
+        (
+            _EXACT,
+            [
+                genai_types.FunctionCall(name="t1", args={}),
+                genai_types.FunctionCall(name="t2", args={}),
+            ],
+            [genai_types.FunctionCall(name="t1", args={})],
+            0.0,
+        ),
+        # EXACT: empty lists -> pass
+        (
+            _EXACT,
+            [],
+            [],
+            1.0,
+        ),
+        # IN_ORDER: different args with extra tools -> pass
+        (
+            _IN_ORDER,
+            [
+                genai_types.FunctionCall(name="t1", args={"a": 1}),
+                genai_types.FunctionCall(name="extra", args={}),
+                genai_types.FunctionCall(name="t2", args={"b": 2}),
+            ],
+            [
+                genai_types.FunctionCall(name="t1", args={"x": 99}),
+                genai_types.FunctionCall(name="t2", args={"y": 100}),
+            ],
+            1.0,
+        ),
+        # IN_ORDER: wrong order -> fail
+        (
+            _IN_ORDER,
+            [
+                genai_types.FunctionCall(name="t2", args={}),
+                genai_types.FunctionCall(name="t1", args={}),
+            ],
+            [
+                genai_types.FunctionCall(name="t1", args={}),
+                genai_types.FunctionCall(name="t2", args={}),
+            ],
+            0.0,
+        ),
+        # IN_ORDER: missing tool -> fail
+        (
+            _IN_ORDER,
+            [genai_types.FunctionCall(name="t1", args={})],
+            [
+                genai_types.FunctionCall(name="t1", args={}),
+                genai_types.FunctionCall(name="t2", args={}),
+            ],
+            0.0,
+        ),
+        # ANY_ORDER: different args, swapped order -> pass
+        (
+            _ANY_ORDER,
+            [
+                genai_types.FunctionCall(name="t2", args={"b": 2}),
+                genai_types.FunctionCall(name="t1", args={"a": 1}),
+            ],
+            [
+                genai_types.FunctionCall(name="t1", args={"x": 99}),
+                genai_types.FunctionCall(name="t2", args={"y": 100}),
+            ],
+            1.0,
+        ),
+        # ANY_ORDER: missing tool -> fail
+        (
+            _ANY_ORDER,
+            [genai_types.FunctionCall(name="t1", args={})],
+            [
+                genai_types.FunctionCall(name="t1", args={}),
+                genai_types.FunctionCall(name="t2", args={}),
+            ],
+            0.0,
+        ),
+    ],
+    ids=[
+        "exact_different_args_pass",
+        "exact_different_names_fail",
+        "exact_different_count_fail",
+        "exact_empty_lists_pass",
+        "in_order_different_args_pass",
+        "in_order_wrong_order_fail",
+        "in_order_missing_tool_fail",
+        "any_order_different_args_pass",
+        "any_order_missing_tool_fail",
+    ],
+)
+def test_ignore_args(match_type, actual_tools, expected_tools, expected_score):
+  ev = _make_ignore_args_evaluator(match_type)
+  actual = Invocation(
+      user_content=_USER_CONTENT,
+      intermediate_data=IntermediateData(tool_uses=actual_tools),
+  )
+  expected = Invocation(
+      user_content=_USER_CONTENT,
+      intermediate_data=IntermediateData(tool_uses=expected_tools),
+  )
+  result = ev.evaluate_invocations([actual], [expected])
+  assert result.overall_score == expected_score
+
+
+def test_ignore_args_false_still_checks_args():
+  """Confirm ignore_args=False (default) still enforces arg matching."""
+  ev = TrajectoryEvaluator(
+      eval_metric=EvalMetric(
+          metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value,
+          criterion=ToolTrajectoryCriterion(
+              threshold=0.5,
+              match_type=ToolTrajectoryCriterion.MatchType.EXACT,
+              ignore_args=False,
+          ),
+      )
+  )
+  actual = Invocation(
+      user_content=_USER_CONTENT,
+      intermediate_data=IntermediateData(
+          tool_uses=[genai_types.FunctionCall(name="t1", args={"a": 1})]
+      ),
+  )
+  expected = Invocation(
+      user_content=_USER_CONTENT,
+      intermediate_data=IntermediateData(
+          tool_uses=[genai_types.FunctionCall(name="t1", args={"a": 2})]
+      ),
+  )
+  result = ev.evaluate_invocations([actual], [expected])
+  assert result.overall_score == 0.0
+
+
+def test_ignore_args_multiple_invocations_mixed():
+  """ignore_args with multiple invocations: one matches, one doesn't."""
+  ev = _make_ignore_args_evaluator(ToolTrajectoryCriterion.MatchType.EXACT)
+  inv1_actual = Invocation(
+      user_content=_USER_CONTENT,
+      intermediate_data=IntermediateData(
+          tool_uses=[genai_types.FunctionCall(name="t1", args={"a": 1})]
+      ),
+  )
+  inv1_expected = Invocation(
+      user_content=_USER_CONTENT,
+      intermediate_data=IntermediateData(
+          tool_uses=[genai_types.FunctionCall(name="t1", args={"z": 99})]
+      ),
+  )
+  inv2_actual = Invocation(
+      user_content=_USER_CONTENT,
+      intermediate_data=IntermediateData(
+          tool_uses=[genai_types.FunctionCall(name="t1", args={})]
+      ),
+  )
+  inv2_expected = Invocation(
+      user_content=_USER_CONTENT,
+      intermediate_data=IntermediateData(
+          tool_uses=[genai_types.FunctionCall(name="t2", args={})]
+      ),
+  )
+  result = ev.evaluate_invocations(
+      [inv1_actual, inv2_actual], [inv1_expected, inv2_expected]
+  )
+  assert result.overall_score == 0.5
+  assert result.per_invocation_results[0].score == 1.0
+  assert result.per_invocation_results[1].score == 0.0
+
+
+def test_ignore_args_with_camel_case_dict_config():
+  """Tests ignore_args works via camelCase key (ignoreArgs) in dict."""
+  eval_metric = EvalMetric(
+      metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value,
+      criterion={
+          "threshold": 0.5,
+          "matchType": "EXACT",
+          "ignoreArgs": True,
+      },
+  )
+  ev = TrajectoryEvaluator(eval_metric=eval_metric)
+  actual = Invocation(
+      user_content=_USER_CONTENT,
+      intermediate_data=IntermediateData(
+          tool_uses=[genai_types.FunctionCall(name="t1", args={"a": 1})]
+      ),
+  )
+  expected = Invocation(
+      user_content=_USER_CONTENT,
+      intermediate_data=IntermediateData(
+          tool_uses=[genai_types.FunctionCall(name="t1", args={"z": 999})]
+      ),
+  )
+  result = ev.evaluate_invocations([actual], [expected])
+  assert result.overall_score == 1.0
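
Usage sketch (not part of the patch): a minimal example of enabling the new flag, mirroring the test helper above. The import paths are assumed from the file layout in this diff (google.adk.evaluation.eval_metrics and google.adk.evaluation.trajectory_evaluator) and may differ from the package's public surface.

    from google.adk.evaluation.eval_metrics import (
        EvalMetric,
        PrebuiltMetrics,
        ToolTrajectoryCriterion,
    )
    from google.adk.evaluation.trajectory_evaluator import TrajectoryEvaluator

    # Score tool trajectories on tool names only: with ignore_args=True,
    # argument mismatches no longer fail a match. The flag defaults to
    # False, so existing configs keep the old name-and-args comparison.
    evaluator = TrajectoryEvaluator(
        eval_metric=EvalMetric(
            metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value,
            criterion=ToolTrajectoryCriterion(
                threshold=1.0,
                match_type=ToolTrajectoryCriterion.MatchType.IN_ORDER,
                ignore_args=True,  # new flag introduced in this change
            ),
        )
    )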