diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index cfbadb9fe..8e654485c 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -144,11 +144,15 @@ protected Flowable postprocess( }) .map(ResponseProcessingResult::updatedResponse); } + Context parentContext = Context.current(); return currentLlmResponse.flatMapPublisher( - updatedResponse -> - buildPostprocessingEvents( - updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest)); + updatedResponse -> { + try (Scope scope = parentContext.makeCurrent()) { + return buildPostprocessingEvents( + updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest); + } + }); } /** @@ -160,7 +164,10 @@ protected Flowable postprocess( * callbacks. Callbacks should not rely on its ID if they create their own separate events. */ private Flowable callLlm( - InvocationContext context, LlmRequest llmRequest, Event eventForCallbackUsage) { + InvocationContext context, + LlmRequest llmRequest, + Event eventForCallbackUsage, + Context parentTracingContext) { LlmAgent agent = (LlmAgent) context.agent(); LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); @@ -180,7 +187,7 @@ private Flowable callLlm( Span llmCallSpan = Tracing.getTracer() .spanBuilder("call_llm") - .setParent(Context.current()) + .setParent(parentTracingContext) .startSpan(); try (Scope scope = llmCallSpan.makeCurrent()) { @@ -333,6 +340,7 @@ private Single handleAfterModelCallback( * @throws IllegalStateException if a transfer agent is specified but not found. */ private Flowable runOneStep(InvocationContext context) { + Context parentContext = Context.current(); AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); Flowable preprocessEvents = preprocess(context, llmRequestRef); @@ -363,10 +371,12 @@ private Flowable runOneStep(InvocationContext context) { // events with fresh timestamp. mutableEventTemplate.setTimestamp(0L); - return callLlm(context, llmRequestAfterPreprocess, mutableEventTemplate) + return callLlm( + context, llmRequestAfterPreprocess, mutableEventTemplate, parentContext) .concatMap( - llmResponse -> - postprocess( + llmResponse -> { + try (Scope scope = parentContext.makeCurrent()) { + return postprocess( context, mutableEventTemplate, llmRequestAfterPreprocess, @@ -380,7 +390,9 @@ private Flowable runOneStep(InvocationContext context) { + " next LlmResponse", oldId, mutableEventTemplate.id()); - })) + }); + } + }) .concatMap( event -> { Flowable postProcessedEvents = Flowable.just(event); @@ -421,6 +433,7 @@ private Flowable run(InvocationContext invocationContext, int stepsComple return currentStepEvents; } + Context parentContext = Context.current(); return currentStepEvents.concatWith( currentStepEvents .toList() @@ -435,7 +448,12 @@ private Flowable run(InvocationContext invocationContext, int stepsComple return Flowable.empty(); } else { logger.debug("Continuing to next step of the flow."); - return Flowable.defer(() -> run(invocationContext, stepsCompleted + 1)); + return Flowable.defer( + () -> { + try (Scope scope = parentContext.makeCurrent()) { + return run(invocationContext, stepsCompleted + 1); + } + }); } })); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 3bb57faee..a6fb74d88 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -245,6 +245,7 @@ private static Function> getFunctionCallMapper( Map tools, Map toolConfirmations, boolean isLive) { + Context parentContext = Context.current(); return functionCall -> { BaseTool tool = tools.get(functionCall.name().get()); ToolContext toolContext = @@ -259,14 +260,19 @@ private static Function> getFunctionCallMapper( maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) .switchIfEmpty( Maybe.defer( - () -> - isLive + () -> { + try (Scope scope = parentContext.makeCurrent()) { + return isLive ? processFunctionLive( invocationContext, tool, toolContext, functionCall, functionArgs) - : callTool(tool, functionArgs, toolContext))); + : callTool(tool, functionArgs, toolContext); + } + })); - return postProcessFunctionResult( - maybeFunctionResult, invocationContext, tool, functionArgs, toolContext, isLive); + try (Scope scope = parentContext.makeCurrent()) { + return postProcessFunctionResult( + maybeFunctionResult, invocationContext, tool, functionArgs, toolContext, isLive); + } }; } @@ -372,6 +378,7 @@ private static Maybe postProcessFunctionResult( Map functionArgs, ToolContext toolContext, boolean isLive) { + Context parentContext = Context.current(); return maybeFunctionResult .map(Optional::of) .defaultIfEmpty(Optional.empty()) @@ -393,14 +400,17 @@ private static Maybe postProcessFunctionResult( .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) .flatMapMaybe( finalOptionalResult -> { - Map finalFunctionResult = finalOptionalResult.orElse(null); - if (tool.longRunning() && finalFunctionResult == null) { - return Maybe.empty(); + try (Scope scope = parentContext.makeCurrent()) { + Map finalFunctionResult = + finalOptionalResult.orElse(null); + if (tool.longRunning() && finalFunctionResult == null) { + return Maybe.empty(); + } + Event functionResponseEvent = + buildResponseEvent( + tool, finalFunctionResult, toolContext, invocationContext); + return Maybe.just(functionResponseEvent); } - Event functionResponseEvent = - buildResponseEvent( - tool, finalFunctionResult, toolContext, invocationContext); - return Maybe.just(functionResponseEvent); }); }); } @@ -552,12 +562,13 @@ private static Maybe> maybeInvokeAfterToolCall( private static Maybe> callTool( BaseTool tool, Map args, ToolContext toolContext) { Tracer tracer = Tracing.getTracer(); + Context parentContext = Context.current(); return Maybe.defer( () -> { Span span = tracer .spanBuilder("tool_call [" + tool.name() + "]") - .setParent(Context.current()) + .setParent(parentContext) .startSpan(); try (Scope scope = span.makeCurrent()) { Tracing.traceToolCall(args);