diff --git a/core/src/main/java/com/google/adk/tools/BaseTool.java b/core/src/main/java/com/google/adk/tools/BaseTool.java index 1ea2808a1..01a399920 100644 --- a/core/src/main/java/com/google/adk/tools/BaseTool.java +++ b/core/src/main/java/com/google/adk/tools/BaseTool.java @@ -22,6 +22,7 @@ import com.fasterxml.jackson.annotation.JsonAnySetter; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.JsonBaseModel; import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.models.LlmRequest; @@ -38,6 +39,7 @@ import java.util.HashMap; import java.util.Map; import java.util.Optional; +import java.util.function.Function; import javax.annotation.Nonnull; import org.jspecify.annotations.Nullable; import org.slf4j.Logger; @@ -93,6 +95,85 @@ public Single> runAsync(Map args, ToolContex throw new UnsupportedOperationException("This method is not implemented."); } + /** + * Calls a tool with generic arguments and returns a map of results. The args type {@code T} need + * to be serializable with {@link JsonBaseModel#getMapper()} + */ + public final Single> runAsync(T args, ToolContext toolContext) { + return runAsync(args, toolContext, JsonBaseModel.getMapper()); + } + + /** + * Calls a tool with generic arguments using a custom {@link ObjectMapper} and returns a map of + * results. The args type {@code T} needs to be serializable with the provided {@link + * ObjectMapper}. + */ + public final Single> runAsync( + T args, ToolContext toolContext, ObjectMapper objectMapper) { + return runAsync(args, toolContext, objectMapper, output -> output); + } + + /** + * Calls a tool with generic arguments and a custom {@link ObjectMapper}, returning the results + * converted to a specified class. The input type {@code I} needs to be serializable and the + * output type {@code O} needs to be deserializable with the provided {@link ObjectMapper}. + */ + public final Single runAsync( + I args, ToolContext toolContext, ObjectMapper objectMapper, Class oClass) { + return runAsync( + args, toolContext, objectMapper, output -> objectMapper.convertValue(output, oClass)); + } + + /** + * Calls a tool with generic arguments and a custom {@link ObjectMapper}, returning the results + * converted to a specified type reference. The input type {@code I} needs to be serializable and + * the output type {@code O} needs to be deserializable with the provided {@link ObjectMapper}. + */ + public final Single runAsync( + I args, + ToolContext toolContext, + ObjectMapper objectMapper, + TypeReference typeReference) { + return runAsync( + args, + toolContext, + objectMapper, + output -> objectMapper.convertValue(output, typeReference)); + } + + /** + * Calls a tool with generic arguments, returning the results converted to a specified class. The + * input type {@code I} needs to be serializable and the output type {@code O} needs to be + * deserializable with {@link JsonBaseModel#getMapper()} + */ + public final Single runAsync( + I args, ToolContext toolContext, Class oClass) { + return runAsync(args, toolContext, JsonBaseModel.getMapper(), oClass); + } + + /** + * Calls a tool with generic arguments, returning the results converted to a specified type + * reference. The input type needs to be serializable and the output type needs to be + * deserializable with {@link JsonBaseModel#getMapper()} + */ + public final Single runAsync( + I args, ToolContext toolContext, TypeReference typeReference) { + return runAsync(args, toolContext, JsonBaseModel.getMapper(), typeReference); + } + + private Single runAsync( + I args, + ToolContext toolContext, + ObjectMapper objectMapper, + Function, ? extends O> deserializer) { + return Single.defer( + () -> + Single.just( + objectMapper.convertValue(args, new TypeReference>() {}))) + .flatMap(argsMap -> runAsync(argsMap, toolContext)) + .map(deserializer::apply); + } + /** * Processes the outgoing {@link LlmRequest.Builder}. * diff --git a/core/src/test/java/com/google/adk/tools/BaseToolTest.java b/core/src/test/java/com/google/adk/tools/BaseToolTest.java index 2a07e7a44..d3c8da5aa 100644 --- a/core/src/test/java/com/google/adk/tools/BaseToolTest.java +++ b/core/src/test/java/com/google/adk/tools/BaseToolTest.java @@ -2,12 +2,15 @@ import static com.google.common.truth.Truth.assertThat; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LlmAgent; import com.google.adk.models.Gemini; import com.google.adk.models.LlmRequest; import com.google.adk.sessions.InMemorySessionService; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.GenerateContentConfig; import com.google.genai.types.GoogleMaps; @@ -17,6 +20,7 @@ import com.google.genai.types.UrlContext; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.observers.TestObserver; import java.util.Map; import java.util.Optional; import org.junit.Test; @@ -27,6 +31,20 @@ @RunWith(JUnit4.class) public final class BaseToolTest { + private final BaseTool doublingBaseTool = + new BaseTool("doubling-test-tool", "returns doubled args") { + @Override + public Single> runAsync( + Map args, ToolContext toolContext) { + String sArg = (String) args.get("s"); + Integer iArg = (Integer) args.get("i"); + return Single.just( + ImmutableMap.of( + "s", sArg + sArg, + "i", iArg + iArg)); + } + }; + @Test public void processLlmRequestNoDeclarationReturnsSameRequest() { BaseTool tool = @@ -247,4 +265,94 @@ public void processLlmRequestWithGoogleMapsToolAddsToolToConfig() { assertThat(updatedLlmRequest.config().get().tools().get()) .containsExactly(Tool.builder().googleMaps(GoogleMaps.builder().build()).build()); } + + @Test + public void runAsync_withTypeReference_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(42, "foo"); + + Single out = + doublingBaseTool.runAsync( + testToolArgs, /* toolContext= */ null, new TypeReference() {}); + TestObserver testObserver = out.test(); + + testObserver.assertComplete(); + TestToolArgs expected = new TestToolArgs(84, "foofoo"); + testObserver.assertValue(expected); + } + + @Test + public void runAsync_withClass_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(21, "bar"); + + Single out = + doublingBaseTool.runAsync(testToolArgs, /* toolContext= */ null, TestToolArgs.class); + TestObserver testObserver = out.test(); + + testObserver.assertComplete(); + TestToolArgs expected = new TestToolArgs(42, "barbar"); + testObserver.assertValue(expected); + } + + @Test + public void runAsync_withObjectOnly_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(11, "baz"); + + Single> out = + doublingBaseTool.runAsync(testToolArgs, /* toolContext= */ null); + TestObserver> testObserver = out.test(); + + testObserver.assertComplete(); + ImmutableMap expected = ImmutableMap.of("i", 22, "s", "bazbaz"); + testObserver.assertValue(expected); + } + + @Test + public void runAsync_withObjectMapperAndObjectOnly_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(11, "baz"); + ObjectMapper objectMapper = new ObjectMapper(); + + Single> out = + doublingBaseTool.runAsync(testToolArgs, /* toolContext= */ null, objectMapper); + TestObserver> testObserver = out.test(); + + testObserver.assertComplete(); + ImmutableMap expected = ImmutableMap.of("i", 22, "s", "bazbaz"); + testObserver.assertValue(expected); + } + + @Test + public void runAsync_withTypeReferenceAndObjectMapper_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(42, "foo"); + ObjectMapper objectMapper = new ObjectMapper(); + + Single out = + doublingBaseTool.runAsync( + testToolArgs, + /* toolContext= */ null, + objectMapper, + new TypeReference() {}); + + TestObserver testObserver = out.test(); + + testObserver.assertComplete(); + TestToolArgs expected = new TestToolArgs(84, "foofoo"); + testObserver.assertValue(expected); + } + + @Test + public void runAsync_withClassAndObjectMapper_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(21, "bar"); + ObjectMapper objectMapper = new ObjectMapper(); + + Single out = + doublingBaseTool.runAsync( + testToolArgs, /* toolContext= */ null, objectMapper, TestToolArgs.class); + TestObserver testObserver = out.test(); + + testObserver.assertComplete(); + TestToolArgs expected = new TestToolArgs(42, "barbar"); + testObserver.assertValue(expected); + } + + public record TestToolArgs(int i, String s) {} }