Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions core/src/main/java/com/google/adk/tools/BaseTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.fasterxml.jackson.annotation.JsonAnySetter;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.adk.JsonBaseModel;
import com.google.adk.agents.ConfigAgentUtils.ConfigurationException;
import com.google.adk.models.LlmRequest;
Expand All @@ -38,6 +39,7 @@
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import javax.annotation.Nonnull;
import org.jspecify.annotations.Nullable;
import org.slf4j.Logger;
Expand Down Expand Up @@ -93,6 +95,85 @@ public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContex
throw new UnsupportedOperationException("This method is not implemented.");
}

/**
* Calls a tool with generic arguments and returns a map of results. The args type {@code T} need
* to be serializable with {@link JsonBaseModel#getMapper()}
*/
public final <T> Single<Map<String, Object>> runAsync(T args, ToolContext toolContext) {
return runAsync(args, toolContext, JsonBaseModel.getMapper());
}

/**
* Calls a tool with generic arguments using a custom {@link ObjectMapper} and returns a map of
* results. The args type {@code T} needs to be serializable with the provided {@link
* ObjectMapper}.
*/
public final <T> Single<Map<String, Object>> runAsync(
T args, ToolContext toolContext, ObjectMapper objectMapper) {
return runAsync(args, toolContext, objectMapper, output -> output);
}

/**
* Calls a tool with generic arguments and a custom {@link ObjectMapper}, returning the results
* converted to a specified class. The input type {@code I} needs to be serializable and the
* output type {@code O} needs to be deserializable with the provided {@link ObjectMapper}.
*/
public final <I, O> Single<O> runAsync(
I args, ToolContext toolContext, ObjectMapper objectMapper, Class<? extends O> oClass) {
return runAsync(
args, toolContext, objectMapper, output -> objectMapper.convertValue(output, oClass));
}

/**
* Calls a tool with generic arguments and a custom {@link ObjectMapper}, returning the results
* converted to a specified type reference. The input type {@code I} needs to be serializable and
* the output type {@code O} needs to be deserializable with the provided {@link ObjectMapper}.
*/
public final <I, O> Single<O> runAsync(
I args,
ToolContext toolContext,
ObjectMapper objectMapper,
TypeReference<? extends O> typeReference) {
return runAsync(
args,
toolContext,
objectMapper,
output -> objectMapper.convertValue(output, typeReference));
}

/**
* Calls a tool with generic arguments, returning the results converted to a specified class. The
* input type {@code I} needs to be serializable and the output type {@code O} needs to be
* deserializable with {@link JsonBaseModel#getMapper()}
*/
public final <I, O> Single<O> runAsync(
I args, ToolContext toolContext, Class<? extends O> oClass) {
return runAsync(args, toolContext, JsonBaseModel.getMapper(), oClass);
}

/**
* Calls a tool with generic arguments, returning the results converted to a specified type
* reference. The input type needs to be serializable and the output type needs to be
* deserializable with {@link JsonBaseModel#getMapper()}
*/
public final <I, O> Single<O> runAsync(
I args, ToolContext toolContext, TypeReference<? extends O> typeReference) {
return runAsync(args, toolContext, JsonBaseModel.getMapper(), typeReference);
}

private <I, O> Single<O> runAsync(
I args,
ToolContext toolContext,
ObjectMapper objectMapper,
Function<? super Map<String, Object>, ? extends O> deserializer) {
return Single.defer(
() ->
Single.just(
objectMapper.convertValue(args, new TypeReference<Map<String, Object>>() {})))
.flatMap(argsMap -> runAsync(argsMap, toolContext))
.map(deserializer::apply);
}

/**
* Processes the outgoing {@link LlmRequest.Builder}.
*
Expand Down
108 changes: 108 additions & 0 deletions core/src/test/java/com/google/adk/tools/BaseToolTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@

import static com.google.common.truth.Truth.assertThat;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.adk.agents.InvocationContext;
import com.google.adk.agents.LlmAgent;
import com.google.adk.models.Gemini;
import com.google.adk.models.LlmRequest;
import com.google.adk.sessions.InMemorySessionService;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.genai.types.FunctionDeclaration;
import com.google.genai.types.GenerateContentConfig;
import com.google.genai.types.GoogleMaps;
Expand All @@ -17,6 +20,7 @@
import com.google.genai.types.UrlContext;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.observers.TestObserver;
import java.util.Map;
import java.util.Optional;
import org.junit.Test;
Expand All @@ -27,6 +31,20 @@
@RunWith(JUnit4.class)
public final class BaseToolTest {

private final BaseTool doublingBaseTool =
new BaseTool("doubling-test-tool", "returns doubled args") {
@Override
public Single<Map<String, Object>> runAsync(
Map<String, Object> args, ToolContext toolContext) {
String sArg = (String) args.get("s");
Integer iArg = (Integer) args.get("i");
return Single.just(
ImmutableMap.<String, Object>of(
"s", sArg + sArg,
"i", iArg + iArg));
}
};

@Test
public void processLlmRequestNoDeclarationReturnsSameRequest() {
BaseTool tool =
Expand Down Expand Up @@ -247,4 +265,94 @@ public void processLlmRequestWithGoogleMapsToolAddsToolToConfig() {
assertThat(updatedLlmRequest.config().get().tools().get())
.containsExactly(Tool.builder().googleMaps(GoogleMaps.builder().build()).build());
}

@Test
public void runAsync_withTypeReference_convertsArguments() throws Exception {
TestToolArgs testToolArgs = new TestToolArgs(42, "foo");

Single<TestToolArgs> out =
doublingBaseTool.runAsync(
testToolArgs, /* toolContext= */ null, new TypeReference<TestToolArgs>() {});
TestObserver<TestToolArgs> testObserver = out.test();

testObserver.assertComplete();
TestToolArgs expected = new TestToolArgs(84, "foofoo");
testObserver.assertValue(expected);
}

@Test
public void runAsync_withClass_convertsArguments() throws Exception {
TestToolArgs testToolArgs = new TestToolArgs(21, "bar");

Single<TestToolArgs> out =
doublingBaseTool.runAsync(testToolArgs, /* toolContext= */ null, TestToolArgs.class);
TestObserver<TestToolArgs> testObserver = out.test();

testObserver.assertComplete();
TestToolArgs expected = new TestToolArgs(42, "barbar");
testObserver.assertValue(expected);
}

@Test
public void runAsync_withObjectOnly_convertsArguments() throws Exception {
TestToolArgs testToolArgs = new TestToolArgs(11, "baz");

Single<Map<String, Object>> out =
doublingBaseTool.runAsync(testToolArgs, /* toolContext= */ null);
TestObserver<Map<String, Object>> testObserver = out.test();

testObserver.assertComplete();
ImmutableMap<String, Object> expected = ImmutableMap.of("i", 22, "s", "bazbaz");
testObserver.assertValue(expected);
}

@Test
public void runAsync_withObjectMapperAndObjectOnly_convertsArguments() throws Exception {
TestToolArgs testToolArgs = new TestToolArgs(11, "baz");
ObjectMapper objectMapper = new ObjectMapper();

Single<Map<String, Object>> out =
doublingBaseTool.runAsync(testToolArgs, /* toolContext= */ null, objectMapper);
TestObserver<Map<String, Object>> testObserver = out.test();

testObserver.assertComplete();
ImmutableMap<String, Object> expected = ImmutableMap.of("i", 22, "s", "bazbaz");
testObserver.assertValue(expected);
}

@Test
public void runAsync_withTypeReferenceAndObjectMapper_convertsArguments() throws Exception {
TestToolArgs testToolArgs = new TestToolArgs(42, "foo");
ObjectMapper objectMapper = new ObjectMapper();

Single<TestToolArgs> out =
doublingBaseTool.runAsync(
testToolArgs,
/* toolContext= */ null,
objectMapper,
new TypeReference<TestToolArgs>() {});

TestObserver<TestToolArgs> testObserver = out.test();

testObserver.assertComplete();
TestToolArgs expected = new TestToolArgs(84, "foofoo");
testObserver.assertValue(expected);
}

@Test
public void runAsync_withClassAndObjectMapper_convertsArguments() throws Exception {
TestToolArgs testToolArgs = new TestToolArgs(21, "bar");
ObjectMapper objectMapper = new ObjectMapper();

Single<TestToolArgs> out =
doublingBaseTool.runAsync(
testToolArgs, /* toolContext= */ null, objectMapper, TestToolArgs.class);
TestObserver<TestToolArgs> testObserver = out.test();

testObserver.assertComplete();
TestToolArgs expected = new TestToolArgs(42, "barbar");
testObserver.assertValue(expected);
}

public record TestToolArgs(int i, String s) {}
}
Loading