Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 58 additions & 19 deletions core/src/main/java/com/google/adk/agents/LlmAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import static java.util.stream.Collectors.joining;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.adk.SchemaUtils;
import com.google.adk.agents.Callbacks.AfterAgentCallbackSync;
import com.google.adk.agents.Callbacks.AfterModelCallback;
Expand Down Expand Up @@ -50,12 +51,15 @@
import com.google.adk.flows.llmflows.SingleFlow;
import com.google.adk.models.BaseLlm;
import com.google.adk.models.LlmRegistry;
import com.google.adk.models.LlmResponse;
import com.google.adk.models.Model;
import com.google.adk.tools.BaseTool;
import com.google.adk.tools.BaseToolset;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.genai.types.Blob;
import com.google.genai.types.Content;
import com.google.genai.types.GenerateContentConfig;
import com.google.genai.types.Part;
Expand All @@ -64,6 +68,7 @@
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.core.Maybe;
import io.reactivex.rxjava3.core.Single;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -134,7 +139,13 @@ protected LlmAgent(Builder builder) {
this.disallowTransferToParent = requireNonNullElse(builder.disallowTransferToParent, false);
this.disallowTransferToPeers = requireNonNullElse(builder.disallowTransferToPeers, false);
this.beforeModelCallback = requireNonNullElse(builder.beforeModelCallback, ImmutableList.of());
this.afterModelCallback = requireNonNullElse(builder.afterModelCallback, ImmutableList.of());
List<AfterModelCallback> afterCallbacks = new ArrayList<>();
if (builder.outputKey != null) {
afterCallbacks.add(
new OutputKeySaverCallback(builder.outputKey, Optional.ofNullable(builder.outputSchema)));
}
afterCallbacks.addAll(requireNonNullElse(builder.afterModelCallback, ImmutableList.of()));
this.afterModelCallback = ImmutableList.copyOf(afterCallbacks);
this.onModelErrorCallback =
requireNonNullElse(builder.onModelErrorCallback, ImmutableList.of());
this.beforeToolCallback = requireNonNullElse(builder.beforeToolCallback, ImmutableList.of());
Expand Down Expand Up @@ -610,41 +621,69 @@ protected BaseLlmFlow determineLlmFlow() {
}
}

private void maybeSaveOutputToState(Event event) {
if (outputKey().isPresent() && event.finalResponse() && event.content().isPresent()) {
// Concatenate text from all parts, excluding thoughts.
Object output;
private static class OutputKeySaverCallback implements AfterModelCallback {
private static final ObjectMapper objectMapper = new ObjectMapper();
private final String outputKey;
private final Optional<Schema> outputSchema;

private OutputKeySaverCallback(String outputKey, Optional<Schema> outputSchema) {
this.outputKey = outputKey;
this.outputSchema = outputSchema;
}

@Override
public Maybe<LlmResponse> call(CallbackContext context, LlmResponse response) {
if (response.content().isEmpty()) {
return Maybe.empty();
}

Content originalContent = response.content().get();
String rawResult =
event.content().flatMap(Content::parts).orElseGet(ImmutableList::of).stream()
originalContent.parts().orElse(ImmutableList.of()).stream()
.filter(part -> !isThought(part))
.map(part -> part.text().orElse(""))
.collect(joining());

Optional<Schema> outputSchema = outputSchema();
Object output;
if (outputSchema.isPresent()) {
try {
Map<String, Object> validatedMap =
SchemaUtils.validateOutputSchema(rawResult, outputSchema.get());
output = validatedMap;
} catch (JsonProcessingException e) {
logger.error(
"LlmAgent output for outputKey '{}' was not valid JSON, despite an outputSchema being"
+ " present. Saving raw output to state.",
outputKey().get(),
e);
output = rawResult;
} catch (IllegalArgumentException e) {
} catch (JsonProcessingException | IllegalArgumentException e) {
logger.error(
"LlmAgent output for outputKey '{}' did not match the outputSchema. Saving raw output"
+ " to state.",
outputKey().get(),
outputKey,
e);
output = rawResult;
}
} else {
output = rawResult;
}
event.actions().stateDelta().put(outputKey().get(), output);

String jsonMetadata;
try {
jsonMetadata = objectMapper.writeValueAsString(ImmutableMap.of(outputKey, output));
} catch (JsonProcessingException e) {
return Maybe.error(e);
}

Part stateDeltaPart =
Part.builder()
.inlineData(
Blob.builder()
.data(jsonMetadata.getBytes(StandardCharsets.UTF_8))
.mimeType("application/json+metadata")
.build())
.build();

List<Part> newParts = new ArrayList<>(originalContent.parts().orElse(ImmutableList.of()));
newParts.add(stateDeltaPart);

Content newContent = originalContent.toBuilder().parts(newParts).build();

return Maybe.just(response.toBuilder().content(newContent).build());
}
}

Expand All @@ -654,12 +693,12 @@ private static boolean isThought(Part part) {

@Override
protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
return llmFlow.run(invocationContext).doOnNext(this::maybeSaveOutputToState);
return llmFlow.run(invocationContext);
}

@Override
protected Flowable<Event> runLiveImpl(InvocationContext invocationContext) {
return llmFlow.runLive(invocationContext).doOnNext(this::maybeSaveOutputToState);
return llmFlow.runLive(invocationContext);
}

/**
Expand Down
124 changes: 108 additions & 16 deletions core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@

package com.google.adk.flows.llmflows;

import static java.nio.charset.StandardCharsets.UTF_8;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.adk.agents.ActiveStreamingTool;
import com.google.adk.agents.BaseAgent;
import com.google.adk.agents.CallbackContext;
Expand All @@ -28,6 +33,7 @@
import com.google.adk.agents.ReadonlyContext;
import com.google.adk.agents.RunConfig.StreamingMode;
import com.google.adk.events.Event;
import com.google.adk.events.EventActions;
import com.google.adk.flows.BaseFlow;
import com.google.adk.flows.llmflows.RequestProcessor.RequestProcessingResult;
import com.google.adk.flows.llmflows.ResponseProcessor.ResponseProcessingResult;
Expand All @@ -41,7 +47,9 @@
import com.google.adk.tools.ToolContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.genai.types.Content;
import com.google.genai.types.FunctionResponse;
import com.google.genai.types.Part;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.StatusCode;
import io.opentelemetry.context.Context;
Expand All @@ -54,7 +62,9 @@
import io.reactivex.rxjava3.observers.DisposableCompletableObserver;
import io.reactivex.rxjava3.schedulers.Schedulers;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
Expand All @@ -64,6 +74,7 @@
/** A basic flow that calls the LLM in a loop until a final response is generated. */
public abstract class BaseLlmFlow implements BaseFlow {
private static final Logger logger = LoggerFactory.getLogger(BaseLlmFlow.class);
private static final ObjectMapper objectMapper = new ObjectMapper();

protected final List<RequestProcessor> requestProcessors;
protected final List<ResponseProcessor> responseProcessors;
Expand Down Expand Up @@ -349,14 +360,19 @@ private Single<LlmResponse> handleAfterModelCallback(

Maybe<LlmResponse> callbackResult =
Maybe.defer(
() ->
Flowable.fromIterable(callbacks)
.concatMapMaybe(
callback ->
() -> {
Single<LlmResponse> currentResponse = Single.just(llmResponse);
for (AfterModelCallback callback : callbacks) {
currentResponse =
currentResponse.flatMap(
resp ->
callback
.call(callbackContext, llmResponse)
.compose(Tracing.withContext(currentContext)))
.firstElement());
.call(callbackContext, resp)
.compose(Tracing.withContext(currentContext))
.defaultIfEmpty(resp));
}
return currentResponse.toMaybe();
});

return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse);
}
Expand Down Expand Up @@ -461,14 +477,37 @@ public Flowable<Event> run(InvocationContext invocationContext) {

private Flowable<Event> run(
Context spanContext, InvocationContext invocationContext, int stepsCompleted) {
Flowable<Event> currentStepEvents = runOneStep(spanContext, invocationContext).cache();
Flowable<Event> currentStepEvents = runOneStep(spanContext, invocationContext);

Flowable<Event> processedEvents =
currentStepEvents
.concatMap(
event -> {
if (invocationContext.session().events().stream()
.anyMatch(e -> e.id() != null && e.id().equals(event.id()))) {
logger.debug("Event {} already in session, skipping append", event.id());
return Flowable.just(event);
}
return invocationContext
.sessionService()
.appendEvent(invocationContext.session(), event)
.flatMap(
registeredEvent ->
invocationContext
.pluginManager()
.onEventCallback(invocationContext, registeredEvent)
.defaultIfEmpty(registeredEvent))
.toFlowable();
})
.cache();

if (stepsCompleted + 1 >= maxSteps) {
logger.debug("Ending flow execution because max steps reached.");
return currentStepEvents;
return processedEvents;
}

return currentStepEvents.concatWith(
currentStepEvents
return processedEvents.concatWith(
processedEvents
.toList()
.flatMapPublisher(
eventList -> {
Expand Down Expand Up @@ -685,22 +724,75 @@ private Flowable<Event> buildPostprocessingEvents(

Event modelResponseEvent =
buildModelResponseEvent(baseEventForLlmResponse, llmRequest, updatedResponse);
if (modelResponseEvent.functionCalls().isEmpty()) {
return processorEvents.concatWith(Flowable.just(modelResponseEvent));

if (context.agent() instanceof LlmAgent agent) {
Optional<String> outputKeyOpt = agent.outputKey();
if (outputKeyOpt.isPresent() && modelResponseEvent.content().isPresent()) {
Content content = modelResponseEvent.content().get();
Map<String, Object> extractedDelta = new HashMap<>();
List<Part> cleanParts = new ArrayList<>();
boolean metadataFound = false;
for (Part part : content.parts().orElse(ImmutableList.of())) {
if (part.inlineData().isPresent()
&& part.inlineData()
.get()
.mimeType()
.orElse("")
.equals("application/json+metadata")) {
metadataFound = true;
byte[] data = part.inlineData().get().data().orElse(null);
if (data != null) {
String json = new String(data, UTF_8);
try {
Map<String, Object> metadata =
objectMapper.readValue(json, new TypeReference<Map<String, Object>>() {});
extractedDelta.putAll(metadata);
} catch (JsonProcessingException e) {
logger.error("Failed to parse metadata from inlineData", e);
}
}
} else {
cleanParts.add(part);
}
}

if (metadataFound) {
Event.Builder updatedEventBuilder = modelResponseEvent.toBuilder();
Content newContent =
Content.builder().role(content.role().orElse("model")).parts(cleanParts).build();
updatedEventBuilder.content(newContent);

if (!extractedDelta.isEmpty() && modelResponseEvent.finalResponse()) {
Map<String, Object> newStateDelta =
new HashMap<>(modelResponseEvent.actions().stateDelta());
newStateDelta.putAll(extractedDelta);
EventActions updatedActions =
modelResponseEvent.actions().toBuilder().stateDelta(newStateDelta).build();
updatedEventBuilder.actions(updatedActions);
}
modelResponseEvent = updatedEventBuilder.build();
}
}
}
final Event finalModelResponseEvent = modelResponseEvent;

if (finalModelResponseEvent.functionCalls().isEmpty()) {
return processorEvents.concatWith(Flowable.just(finalModelResponseEvent));
}

Flowable<Event> functionEvents;
try (Scope scope = parentContext.makeCurrent()) {
Maybe<Event> maybeFunctionResponseEvent =
context.runConfig().streamingMode() == StreamingMode.BIDI
? Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools())
: Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools());
? Functions.handleFunctionCallsLive(
context, finalModelResponseEvent, llmRequest.tools())
: Functions.handleFunctionCalls(context, finalModelResponseEvent, llmRequest.tools());
functionEvents =
maybeFunctionResponseEvent.flatMapPublisher(
functionResponseEvent -> {
Optional<Event> toolConfirmationEvent =
Functions.generateRequestConfirmationEvent(
context, modelResponseEvent, functionResponseEvent);
context, finalModelResponseEvent, functionResponseEvent);
List<Event> events = new ArrayList<>();
toolConfirmationEvent.ifPresent(events::add);
events.add(functionResponseEvent);
Expand Down
Loading