diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index 98bba4606..dc8b6d58b 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -21,6 +21,7 @@ import static java.util.stream.Collectors.joining; import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.SchemaUtils; import com.google.adk.agents.Callbacks.AfterAgentCallbackSync; import com.google.adk.agents.Callbacks.AfterModelCallback; @@ -50,12 +51,15 @@ import com.google.adk.flows.llmflows.SingleFlow; import com.google.adk.models.BaseLlm; import com.google.adk.models.LlmRegistry; +import com.google.adk.models.LlmResponse; import com.google.adk.models.Model; import com.google.adk.tools.BaseTool; import com.google.adk.tools.BaseToolset; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.genai.types.Blob; import com.google.genai.types.Content; import com.google.genai.types.GenerateContentConfig; import com.google.genai.types.Part; @@ -64,6 +68,7 @@ import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -134,7 +139,13 @@ protected LlmAgent(Builder builder) { this.disallowTransferToParent = requireNonNullElse(builder.disallowTransferToParent, false); this.disallowTransferToPeers = requireNonNullElse(builder.disallowTransferToPeers, false); this.beforeModelCallback = requireNonNullElse(builder.beforeModelCallback, ImmutableList.of()); - this.afterModelCallback = requireNonNullElse(builder.afterModelCallback, ImmutableList.of()); + List afterCallbacks = new ArrayList<>(); + if (builder.outputKey != null) { + afterCallbacks.add( + new OutputKeySaverCallback(builder.outputKey, Optional.ofNullable(builder.outputSchema))); + } + afterCallbacks.addAll(requireNonNullElse(builder.afterModelCallback, ImmutableList.of())); + this.afterModelCallback = ImmutableList.copyOf(afterCallbacks); this.onModelErrorCallback = requireNonNullElse(builder.onModelErrorCallback, ImmutableList.of()); this.beforeToolCallback = requireNonNullElse(builder.beforeToolCallback, ImmutableList.of()); @@ -610,41 +621,69 @@ protected BaseLlmFlow determineLlmFlow() { } } - private void maybeSaveOutputToState(Event event) { - if (outputKey().isPresent() && event.finalResponse() && event.content().isPresent()) { - // Concatenate text from all parts, excluding thoughts. - Object output; + private static class OutputKeySaverCallback implements AfterModelCallback { + private static final ObjectMapper objectMapper = new ObjectMapper(); + private final String outputKey; + private final Optional outputSchema; + + private OutputKeySaverCallback(String outputKey, Optional outputSchema) { + this.outputKey = outputKey; + this.outputSchema = outputSchema; + } + + @Override + public Maybe call(CallbackContext context, LlmResponse response) { + if (response.content().isEmpty()) { + return Maybe.empty(); + } + + Content originalContent = response.content().get(); String rawResult = - event.content().flatMap(Content::parts).orElseGet(ImmutableList::of).stream() + originalContent.parts().orElse(ImmutableList.of()).stream() .filter(part -> !isThought(part)) .map(part -> part.text().orElse("")) .collect(joining()); - Optional outputSchema = outputSchema(); + Object output; if (outputSchema.isPresent()) { try { Map validatedMap = SchemaUtils.validateOutputSchema(rawResult, outputSchema.get()); output = validatedMap; - } catch (JsonProcessingException e) { - logger.error( - "LlmAgent output for outputKey '{}' was not valid JSON, despite an outputSchema being" - + " present. Saving raw output to state.", - outputKey().get(), - e); - output = rawResult; - } catch (IllegalArgumentException e) { + } catch (JsonProcessingException | IllegalArgumentException e) { logger.error( "LlmAgent output for outputKey '{}' did not match the outputSchema. Saving raw output" + " to state.", - outputKey().get(), + outputKey, e); output = rawResult; } } else { output = rawResult; } - event.actions().stateDelta().put(outputKey().get(), output); + + String jsonMetadata; + try { + jsonMetadata = objectMapper.writeValueAsString(ImmutableMap.of(outputKey, output)); + } catch (JsonProcessingException e) { + return Maybe.error(e); + } + + Part stateDeltaPart = + Part.builder() + .inlineData( + Blob.builder() + .data(jsonMetadata.getBytes(StandardCharsets.UTF_8)) + .mimeType("application/json+metadata") + .build()) + .build(); + + List newParts = new ArrayList<>(originalContent.parts().orElse(ImmutableList.of())); + newParts.add(stateDeltaPart); + + Content newContent = originalContent.toBuilder().parts(newParts).build(); + + return Maybe.just(response.toBuilder().content(newContent).build()); } } @@ -654,12 +693,12 @@ private static boolean isThought(Part part) { @Override protected Flowable runAsyncImpl(InvocationContext invocationContext) { - return llmFlow.run(invocationContext).doOnNext(this::maybeSaveOutputToState); + return llmFlow.run(invocationContext); } @Override protected Flowable runLiveImpl(InvocationContext invocationContext) { - return llmFlow.runLive(invocationContext).doOnNext(this::maybeSaveOutputToState); + return llmFlow.runLive(invocationContext); } /** diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index fffeab698..115395b72 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -16,6 +16,11 @@ package com.google.adk.flows.llmflows; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.agents.ActiveStreamingTool; import com.google.adk.agents.BaseAgent; import com.google.adk.agents.CallbackContext; @@ -28,6 +33,7 @@ import com.google.adk.agents.ReadonlyContext; import com.google.adk.agents.RunConfig.StreamingMode; import com.google.adk.events.Event; +import com.google.adk.events.EventActions; import com.google.adk.flows.BaseFlow; import com.google.adk.flows.llmflows.RequestProcessor.RequestProcessingResult; import com.google.adk.flows.llmflows.ResponseProcessor.ResponseProcessingResult; @@ -41,7 +47,9 @@ import com.google.adk.tools.ToolContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; +import com.google.genai.types.Content; import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Part; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.StatusCode; import io.opentelemetry.context.Context; @@ -54,7 +62,9 @@ import io.reactivex.rxjava3.observers.DisposableCompletableObserver; import io.reactivex.rxjava3.schedulers.Schedulers; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.concurrent.atomic.AtomicReference; @@ -64,6 +74,7 @@ /** A basic flow that calls the LLM in a loop until a final response is generated. */ public abstract class BaseLlmFlow implements BaseFlow { private static final Logger logger = LoggerFactory.getLogger(BaseLlmFlow.class); + private static final ObjectMapper objectMapper = new ObjectMapper(); protected final List requestProcessors; protected final List responseProcessors; @@ -349,14 +360,19 @@ private Single handleAfterModelCallback( Maybe callbackResult = Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> + () -> { + Single currentResponse = Single.just(llmResponse); + for (AfterModelCallback callback : callbacks) { + currentResponse = + currentResponse.flatMap( + resp -> callback - .call(callbackContext, llmResponse) - .compose(Tracing.withContext(currentContext))) - .firstElement()); + .call(callbackContext, resp) + .compose(Tracing.withContext(currentContext)) + .defaultIfEmpty(resp)); + } + return currentResponse.toMaybe(); + }); return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse); } @@ -461,14 +477,37 @@ public Flowable run(InvocationContext invocationContext) { private Flowable run( Context spanContext, InvocationContext invocationContext, int stepsCompleted) { - Flowable currentStepEvents = runOneStep(spanContext, invocationContext).cache(); + Flowable currentStepEvents = runOneStep(spanContext, invocationContext); + + Flowable processedEvents = + currentStepEvents + .concatMap( + event -> { + if (invocationContext.session().events().stream() + .anyMatch(e -> e.id() != null && e.id().equals(event.id()))) { + logger.debug("Event {} already in session, skipping append", event.id()); + return Flowable.just(event); + } + return invocationContext + .sessionService() + .appendEvent(invocationContext.session(), event) + .flatMap( + registeredEvent -> + invocationContext + .pluginManager() + .onEventCallback(invocationContext, registeredEvent) + .defaultIfEmpty(registeredEvent)) + .toFlowable(); + }) + .cache(); + if (stepsCompleted + 1 >= maxSteps) { logger.debug("Ending flow execution because max steps reached."); - return currentStepEvents; + return processedEvents; } - return currentStepEvents.concatWith( - currentStepEvents + return processedEvents.concatWith( + processedEvents .toList() .flatMapPublisher( eventList -> { @@ -685,22 +724,75 @@ private Flowable buildPostprocessingEvents( Event modelResponseEvent = buildModelResponseEvent(baseEventForLlmResponse, llmRequest, updatedResponse); - if (modelResponseEvent.functionCalls().isEmpty()) { - return processorEvents.concatWith(Flowable.just(modelResponseEvent)); + + if (context.agent() instanceof LlmAgent agent) { + Optional outputKeyOpt = agent.outputKey(); + if (outputKeyOpt.isPresent() && modelResponseEvent.content().isPresent()) { + Content content = modelResponseEvent.content().get(); + Map extractedDelta = new HashMap<>(); + List cleanParts = new ArrayList<>(); + boolean metadataFound = false; + for (Part part : content.parts().orElse(ImmutableList.of())) { + if (part.inlineData().isPresent() + && part.inlineData() + .get() + .mimeType() + .orElse("") + .equals("application/json+metadata")) { + metadataFound = true; + byte[] data = part.inlineData().get().data().orElse(null); + if (data != null) { + String json = new String(data, UTF_8); + try { + Map metadata = + objectMapper.readValue(json, new TypeReference>() {}); + extractedDelta.putAll(metadata); + } catch (JsonProcessingException e) { + logger.error("Failed to parse metadata from inlineData", e); + } + } + } else { + cleanParts.add(part); + } + } + + if (metadataFound) { + Event.Builder updatedEventBuilder = modelResponseEvent.toBuilder(); + Content newContent = + Content.builder().role(content.role().orElse("model")).parts(cleanParts).build(); + updatedEventBuilder.content(newContent); + + if (!extractedDelta.isEmpty() && modelResponseEvent.finalResponse()) { + Map newStateDelta = + new HashMap<>(modelResponseEvent.actions().stateDelta()); + newStateDelta.putAll(extractedDelta); + EventActions updatedActions = + modelResponseEvent.actions().toBuilder().stateDelta(newStateDelta).build(); + updatedEventBuilder.actions(updatedActions); + } + modelResponseEvent = updatedEventBuilder.build(); + } + } + } + final Event finalModelResponseEvent = modelResponseEvent; + + if (finalModelResponseEvent.functionCalls().isEmpty()) { + return processorEvents.concatWith(Flowable.just(finalModelResponseEvent)); } Flowable functionEvents; try (Scope scope = parentContext.makeCurrent()) { Maybe maybeFunctionResponseEvent = context.runConfig().streamingMode() == StreamingMode.BIDI - ? Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools()) - : Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools()); + ? Functions.handleFunctionCallsLive( + context, finalModelResponseEvent, llmRequest.tools()) + : Functions.handleFunctionCalls(context, finalModelResponseEvent, llmRequest.tools()); functionEvents = maybeFunctionResponseEvent.flatMapPublisher( functionResponseEvent -> { Optional toolConfirmationEvent = Functions.generateRequestConfirmationEvent( - context, modelResponseEvent, functionResponseEvent); + context, finalModelResponseEvent, functionResponseEvent); List events = new ArrayList<>(); toolConfirmationEvent.ifPresent(events::add); events.add(functionResponseEvent); diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index bc810f28f..8f831da7e 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -47,6 +47,7 @@ import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.disposables.Disposable; import io.reactivex.rxjava3.functions.Function; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; @@ -70,8 +71,17 @@ public final class Functions { private static final Logger logger = LoggerFactory.getLogger(Functions.class); /** Generates a unique ID for a function call. */ - public static String generateClientFunctionCallId() { - return AF_FUNCTION_CALL_ID_PREFIX + UUID.randomUUID(); + public static String generateClientFunctionCallId( + String salt, FunctionCall functionCall, int sequenceNumber) { + String source = + salt + + "-" + + functionCall.name().orElse("") + + functionCall.args().orElse(ImmutableMap.of()).toString() + + "-" + + sequenceNumber; + return AF_FUNCTION_CALL_ID_PREFIX + + UUID.nameUUIDFromBytes(source.getBytes(StandardCharsets.UTF_8)).toString(); } /** @@ -95,12 +105,17 @@ public static void populateClientFunctionCallId(Event modelResponseEvent) { List newParts = new ArrayList<>(); boolean modified = false; + int counter = 0; for (Part part : originalParts) { if (part.functionCall().isPresent()) { FunctionCall functionCall = part.functionCall().get(); if (functionCall.id().isEmpty() || functionCall.id().get().isEmpty()) { FunctionCall updatedFunctionCall = - functionCall.toBuilder().id(generateClientFunctionCallId()).build(); + functionCall.toBuilder() + .id( + generateClientFunctionCallId( + modelResponseEvent.id(), functionCall, counter++)) + .build(); newParts.add(part.toBuilder().functionCall(updatedFunctionCall).build()); modified = true; } else { @@ -626,7 +641,7 @@ private static Event buildResponseEvent( .build(); return Event.builder() - .id(Event.generateEventId()) + .id(toolContext.functionCallId().orElseGet(Event::generateEventId)) .invocationId(invocationContext.invocationId()) .author(invocationContext.agent().name()) .branch(invocationContext.branch().orElse(null)) @@ -662,7 +677,7 @@ public static Optional generateRequestConfirmationEvent( .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)) .entrySet()) { - FunctionCall requestConfirmationFunctionCall = + FunctionCall.Builder builder = FunctionCall.builder() .name(REQUEST_CONFIRMATION_FUNCTION_CALL_NAME) .args( @@ -670,8 +685,10 @@ public static Optional generateRequestConfirmationEvent( "originalFunctionCall", functionCallsById.get(entry.getKey()), "toolConfirmation", - entry.getValue())) - .id(generateClientFunctionCallId()) + entry.getValue())); + FunctionCall requestConfirmationFunctionCall = + builder + .id(generateClientFunctionCallId(functionResponseEvent.id(), builder.build(), 0)) .build(); longRunningToolIds.add(requestConfirmationFunctionCall.id().get()); @@ -687,6 +704,7 @@ public static Optional generateRequestConfirmationEvent( return Optional.of( Event.builder() + .id(Event.generateEventId()) .invocationId(invocationContext.invocationId()) .author(invocationContext.agent().name()) .branch(invocationContext.branch().orElse(null)) diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 44a281f72..ce2bbbc1d 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -68,9 +68,12 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** The main class for the GenAI Agents runner. */ public class Runner { + private static final Logger logger = LoggerFactory.getLogger(Runner.class); private final BaseAgent agent; private final String appName; private final BaseArtifactService artifactService; @@ -570,19 +573,28 @@ private Flowable runAgentWithUpdatedSession( .agent() .runAsync(contextWithUpdatedSession) .concatMap( - agentEvent -> - this.sessionService - .appendEvent(updatedSession, agentEvent) - .flatMap( - registeredEvent -> { - // TODO: remove this hack after deprecating runAsync with Session. - copySessionStates(updatedSession, initialContext.session()); - return contextWithUpdatedSession - .pluginManager() - .onEventCallback(contextWithUpdatedSession, registeredEvent) - .defaultIfEmpty(registeredEvent); - }) - .toFlowable()); + agentEvent -> { + // TODO: remove this hack after deprecating runAsync with Session. + copySessionStates(updatedSession, initialContext.session()); + + // TODO: b/502182243 - Investigate if appendEvent should be made idempotent in + // SessionService to avoid this check. + if (updatedSession.events().stream() + .anyMatch(e -> e.id() != null && e.id().equals(agentEvent.id()))) { + logger.debug("Event {} already in session, skipping append", agentEvent.id()); + return Flowable.just(agentEvent); + } + return this.sessionService + .appendEvent(updatedSession, agentEvent) + .flatMap( + registeredEvent -> { + return contextWithUpdatedSession + .pluginManager() + .onEventCallback(contextWithUpdatedSession, registeredEvent) + .defaultIfEmpty(registeredEvent); + }) + .toFlowable(); + }); // If beforeRunCallback returns content, emit it and skip agent Context capturedContext = Context.current(); diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index e40a83aa0..82e38b225 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -184,6 +184,17 @@ public void testRun_withoutOutputKey_doesNotSaveState() { assertThat(events.get(0).actions().stateDelta()).isEmpty(); } + @Test + public void runAsync_withOutputKeyAndEmptyResponse_doesNotSaveState() { + TestLlm testLlm = createTestLlm(LlmResponse.builder().build()); + LlmAgent agent = createTestAgentBuilder(testLlm).outputKey("myOutput").build(); + InvocationContext invocationContext = createInvocationContext(agent); + + List events = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(events).isEmpty(); + } + @Test public void run_withToolsAndMaxSteps_stopsAfterMaxSteps() { ImmutableMap echoArgs = ImmutableMap.of("arg", "value"); diff --git a/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java index 1b8de4e4f..49c252536 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java @@ -326,6 +326,74 @@ public void populateClientFunctionCallId_withExistingId_noChange() { assertThat(event.content().get().parts().get().get(0).functionCall().get().id()).hasValue(id); } + @Test + public void generateClientFunctionCallId_returnsConsistentId() { + FunctionCall functionCall = + FunctionCall.builder().name("echo_tool").args(ImmutableMap.of("key", "value")).build(); + + String id1 = Functions.generateClientFunctionCallId("test-event-id", functionCall, 0); + String id2 = Functions.generateClientFunctionCallId("test-event-id", functionCall, 0); + + assertThat(id1).isEqualTo(id2); + assertThat(id1).startsWith("adk-"); + } + + @Test + public void generateClientFunctionCallId_returnsDifferentIdForDifferentArgs() { + FunctionCall functionCall1 = + FunctionCall.builder().name("echo_tool").args(ImmutableMap.of("key", "value1")).build(); + FunctionCall functionCall2 = + FunctionCall.builder().name("echo_tool").args(ImmutableMap.of("key", "value2")).build(); + + String id1 = Functions.generateClientFunctionCallId("test-event-id", functionCall1, 0); + String id2 = Functions.generateClientFunctionCallId("test-event-id", functionCall2, 0); + + assertThat(id1).isNotEqualTo(id2); + } + + @Test + public void generateClientFunctionCallId_returnsDifferentIdForDifferentSequenceNumbers() { + FunctionCall functionCall = + FunctionCall.builder().name("echo_tool").args(ImmutableMap.of("key", "value")).build(); + + String id1 = Functions.generateClientFunctionCallId("test-event-id", functionCall, 0); + String id2 = Functions.generateClientFunctionCallId("test-event-id", functionCall, 1); + + assertThat(id1).isNotEqualTo(id2); + } + + @Test + public void generateClientFunctionCallId_returnsDifferentIdForDifferentNames() { + FunctionCall functionCall1 = + FunctionCall.builder().name("echo_tool1").args(ImmutableMap.of("key", "value")).build(); + FunctionCall functionCall2 = + FunctionCall.builder().name("echo_tool2").args(ImmutableMap.of("key", "value")).build(); + + String id1 = Functions.generateClientFunctionCallId("test-event-id", functionCall1, 0); + String id2 = Functions.generateClientFunctionCallId("test-event-id", functionCall2, 0); + + assertThat(id1).isNotEqualTo(id2); + } + + @Test + public void populateClientFunctionCallId_withMultipleCalls_populatesDifferentIds() { + Event event = + createEvent("event").toBuilder() + .content( + Content.fromParts( + Part.fromFunctionCall("echo_tool", ImmutableMap.of("key", "value1")), + Part.fromFunctionCall("echo_tool", ImmutableMap.of("key", "value1")))) + .build(); + + Functions.populateClientFunctionCallId(event); + + var parts = event.content().get().parts().get(); + String id1 = parts.get(0).functionCall().get().id().get(); + String id2 = parts.get(1).functionCall().get().id().get(); + + assertThat(id1).isNotEqualTo(id2); + } + @Test public void getAskUserConfirmationFunctionCalls_eventWithNoContent_returnsEmptyList() { assertThat(Functions.getAskUserConfirmationFunctionCalls(EVENT_WITH_NO_CONTENT)).isEmpty(); diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index ff75c97b0..a79e561d2 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -46,9 +46,12 @@ import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; import com.google.adk.flows.llmflows.Functions; +import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.plugins.BasePlugin; import com.google.adk.sessions.BaseSessionService; +import com.google.adk.sessions.GetSessionConfig; +import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.sessions.SessionKey; import com.google.adk.summarizer.EventsCompactionConfig; @@ -80,6 +83,7 @@ import java.time.Instant; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.UUID; @@ -588,12 +592,22 @@ public void onToolErrorCallback_error() { @Test public void onEventCallback_success() { when(plugin.onEventCallback(any(), any())) - .thenReturn(Maybe.just(TestUtils.createEvent("form plugin"))); + .thenAnswer( + invocation -> { + Event event = invocation.getArgument(1); + return Maybe.just( + Event.builder() + .id(event.id()) + .invocationId(event.invocationId()) + .author("model") + .content(createContent("from plugin")) + .build()); + }); List events = runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet(); - assertThat(simplifyEvents(events)).containsExactly("author: content for event form plugin"); + assertThat(simplifyEvents(events)).containsExactly("model: from plugin"); verify(plugin).onEventCallback(any(), any()); } @@ -1686,4 +1700,168 @@ public void runner_executesSaveArtifactFlow() { // agent was run assertThat(simplifyEvents(events.values())).containsExactly("test agent: from llm"); } + + @Test + public void runAsync_ensuresSequentialConsistencyForTools() { + // Arrange + TestLlm testLlm = + createTestLlm( + createFunctionCallLlmResponse("call_1", "tool1", ImmutableMap.of("arg", "value1")), + createTextLlmResponse("Final response")); + + LlmAgent agent = + createTestAgentBuilder(testLlm) + .tools( + ImmutableList.of( + FunctionTool.create(RaceConditionTools.class, "tool1"), + FunctionTool.create(RaceConditionTools.class, "tool2"))) + .build(); + + BaseSessionService delegate = new InMemorySessionService(); + BaseSessionService delayedSessionService = createDelayedSessionService(delegate, 0); + + Runner runner = + Runner.builder() + .app(App.builder().name("test").rootAgent(agent).build()) + .sessionService(delayedSessionService) + .build(); + Session session = runner.sessionService().createSession("test", "user").blockingGet(); + + // Act + var unused = + runner + .runAsync("user", session.id(), Content.fromParts(Part.fromText("start"))) + .toList() + .blockingGet(); + + // Assert + ImmutableList requests = ImmutableList.copyOf(testLlm.getRequests()); + assertThat(requests).hasSize(2); + + // Second request should contain the result of tool1 + LlmRequest secondRequest = requests.get(1); + List history = secondRequest.contents(); + + boolean foundToolResponse = + history.stream() + .flatMap(content -> content.parts().stream().flatMap(List::stream)) + .filter(part -> part.functionResponse().isPresent()) + .map(part -> part.functionResponse().get()) + .anyMatch( + response -> + response.name().orElse("").equals("tool1") + && response + .response() + .map(r -> Objects.equals(r, ImmutableMap.of("result", "result_value1"))) + .orElse(false)); + + assertThat(foundToolResponse).isTrue(); + } + + @SuppressWarnings("unchecked") // Suppressed because of raw types in mockito matchers. + private static BaseSessionService createDelayedSessionService( + BaseSessionService delegate, long delayMs) { + BaseSessionService delayedSessionService = mock(BaseSessionService.class); + when(delayedSessionService.createSession(anyString(), anyString(), any(Map.class), anyString())) + .thenAnswer( + inv -> + delegate.createSession( + (String) inv.getArgument(0), + (String) inv.getArgument(1), + (Map) inv.getArgument(2), + (String) inv.getArgument(3))); + when(delayedSessionService.createSession(anyString(), anyString())) + .thenAnswer( + inv -> + delegate.createSession((String) inv.getArgument(0), (String) inv.getArgument(1))); + when(delayedSessionService.getSession(anyString(), anyString(), anyString(), any())) + .thenAnswer( + inv -> + delegate.getSession( + (String) inv.getArgument(0), + (String) inv.getArgument(1), + (String) inv.getArgument(2), + (Optional) inv.getArgument(3))); + when(delayedSessionService.appendEvent(any(), any())) + .thenAnswer( + inv -> + delegate + .appendEvent(inv.getArgument(0), inv.getArgument(1)) + .delay(delayMs, MILLISECONDS)); + return delayedSessionService; + } + + public static class RaceConditionTools { + private RaceConditionTools() {} + + public static ImmutableMap tool1(String arg) { + return ImmutableMap.of("result", "result_" + arg); + } + + public static ImmutableMap tool2(String input) { + return ImmutableMap.of("status", "received_" + input); + } + } + + @Test + public void runAsync_withExistingEvents_appendsAgentResponse() { + // Create a session with an existing event + Event existingEvent = + Event.builder() + .id("existing_event_id") + .invocationId("inv1") + .author("user") + .content(Content.fromParts(Part.fromText("existing content"))) + .build(); + session = runner.sessionService().createSession("test", "user").blockingGet(); + var unused = runner.sessionService().appendEvent(session, existingEvent).blockingGet(); + + // Run the agent + var unusedFix = + runner.runAsync("user", session.id(), createContent("new message")).toList().blockingGet(); + + // Verify the agent response was appended + Session updatedSession = + runner + .sessionService() + .getSession("test", "user", session.id(), Optional.empty()) + .blockingGet(); + // Expected events: existingEvent, user message event, agent response event. + assertThat(updatedSession.events()).hasSize(3); + } + + @Test + public void runAsync_skipsDuplicateEvents() throws Exception { + Event existingEvent = + Event.builder() + .id("duplicate_event_id") + .invocationId("inv1") + .author("user") + .content(Content.fromParts(Part.fromText("existing content"))) + .build(); + session = runner.sessionService().createSession("test", "user").blockingGet(); + var unused = runner.sessionService().appendEvent(session, existingEvent).blockingGet(); + + BaseAgent mockAgent = TestUtils.createSubAgent("test agent", existingEvent); + Runner mockRunner = + Runner.builder() + .agent(mockAgent) + .appName("test") + .sessionService(runner.sessionService()) + .build(); + + var unused2 = + mockRunner + .runAsync("user", session.id(), createContent("new message")) + .toList() + .blockingGet(); + + Session updatedSession = + runner + .sessionService() + .getSession("test", "user", session.id(), Optional.empty()) + .blockingGet(); + // Expected events: existingEvent, user message event. Duplicate agent response is skipped. + assertThat(updatedSession.events()).hasSize(2); + } }