From 91a434441f4d4eaf715d7cbeeeb5c525b25d4fdb Mon Sep 17 00:00:00 2001 From: Gautam Datla Date: Tue, 10 Mar 2026 03:23:44 -0400 Subject: [PATCH] fix: populate required for Pydantic BaseModel parameter schemas --- .../tools/_function_parameter_parse_util.py | 7 +++ .../tools/test_build_function_declaration.py | 59 +++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/src/google/adk/tools/_function_parameter_parse_util.py b/src/google/adk/tools/_function_parameter_parse_util.py index 1b9559b29c..03c8ff9d0c 100644 --- a/src/google/adk/tools/_function_parameter_parse_util.py +++ b/src/google/adk/tools/_function_parameter_parse_util.py @@ -399,6 +399,13 @@ def _parse_schema_from_parameter( ), func_name, ) + required_fields = [ + field_name + for field_name, field_info in param.annotation.model_fields.items() + if field_info.is_required() + ] + if required_fields: + schema.required = required_fields _raise_if_schema_unsupported(variant, schema) return schema if inspect.isclass(param.annotation) and issubclass( diff --git a/tests/unittests/tools/test_build_function_declaration.py b/tests/unittests/tools/test_build_function_declaration.py index 1c9bf245f1..608f06cb13 100644 --- a/tests/unittests/tools/test_build_function_declaration.py +++ b/tests/unittests/tools/test_build_function_declaration.py @@ -124,6 +124,65 @@ def simple_function(input: CustomInput) -> str: ) +def test_basemodel_required_fields(): + class SearchRequest(BaseModel): + query: str + max_results: int + filter: str = '' + + def search(request: SearchRequest) -> list: + return [] + + function_decl = _automatic_function_calling_util.build_function_declaration( + func=search + ) + + inner = function_decl.parameters.properties['request'] + assert set(inner.required) == {'query', 'max_results'} + assert 'filter' not in (inner.required or []) + + +def test_basemodel_all_optional_fields_no_required(): + class Config(BaseModel): + timeout: int = 30 + retries: int = 3 + + def run(config: Config) -> str: + return '' + + function_decl = _automatic_function_calling_util.build_function_declaration( + func=run + ) + + inner = function_decl.parameters.properties['config'] + assert not inner.required + + +def test_nested_basemodel_required_fields(): + class Inner(BaseModel): + x: int + y: int = 0 + + class Outer(BaseModel): + inner: Inner + label: str = '' + + def process(data: Outer) -> str: + return '' + + function_decl = _automatic_function_calling_util.build_function_declaration( + func=process + ) + + outer = function_decl.parameters.properties['data'] + assert set(outer.required) == {'inner'} + assert 'label' not in (outer.required or []) + + inner = outer.properties['inner'] + assert set(inner.required) == {'x'} + assert 'y' not in (inner.required or []) + + def test_toolcontext_ignored(): def simple_function(input_str: str, tool_context: ToolContext) -> str: return {'result': input_str}