Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,15 @@ protected Flowable<Event> postprocess(
})
.map(ResponseProcessingResult::updatedResponse);
}
Context parentContext = Context.current();

return currentLlmResponse.flatMapPublisher(
updatedResponse ->
buildPostprocessingEvents(
updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest));
updatedResponse -> {
try (Scope scope = parentContext.makeCurrent()) {
return buildPostprocessingEvents(
updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest);
}
});
}

/**
Expand All @@ -160,7 +164,10 @@ protected Flowable<Event> postprocess(
* callbacks. Callbacks should not rely on its ID if they create their own separate events.
*/
private Flowable<LlmResponse> callLlm(
InvocationContext context, LlmRequest llmRequest, Event eventForCallbackUsage) {
InvocationContext context,
LlmRequest llmRequest,
Event eventForCallbackUsage,
Context parentTracingContext) {
LlmAgent agent = (LlmAgent) context.agent();

LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder();
Expand All @@ -180,7 +187,7 @@ private Flowable<LlmResponse> callLlm(
Span llmCallSpan =
Tracing.getTracer()
.spanBuilder("call_llm")
.setParent(Context.current())
.setParent(parentTracingContext)
.startSpan();

try (Scope scope = llmCallSpan.makeCurrent()) {
Expand Down Expand Up @@ -333,6 +340,7 @@ private Single<LlmResponse> handleAfterModelCallback(
* @throws IllegalStateException if a transfer agent is specified but not found.
*/
private Flowable<Event> runOneStep(InvocationContext context) {
Context parentContext = Context.current();
AtomicReference<LlmRequest> llmRequestRef = new AtomicReference<>(LlmRequest.builder().build());
Flowable<Event> preprocessEvents = preprocess(context, llmRequestRef);

Expand Down Expand Up @@ -363,10 +371,12 @@ private Flowable<Event> runOneStep(InvocationContext context) {
// events with fresh timestamp.
mutableEventTemplate.setTimestamp(0L);

return callLlm(context, llmRequestAfterPreprocess, mutableEventTemplate)
return callLlm(
context, llmRequestAfterPreprocess, mutableEventTemplate, parentContext)
.concatMap(
llmResponse ->
postprocess(
llmResponse -> {
try (Scope scope = parentContext.makeCurrent()) {
return postprocess(
context,
mutableEventTemplate,
llmRequestAfterPreprocess,
Expand All @@ -380,7 +390,9 @@ private Flowable<Event> runOneStep(InvocationContext context) {
+ " next LlmResponse",
oldId,
mutableEventTemplate.id());
}))
});
}
})
.concatMap(
event -> {
Flowable<Event> postProcessedEvents = Flowable.just(event);
Expand Down Expand Up @@ -421,6 +433,7 @@ private Flowable<Event> run(InvocationContext invocationContext, int stepsComple
return currentStepEvents;
}

Context parentContext = Context.current();
return currentStepEvents.concatWith(
currentStepEvents
.toList()
Expand All @@ -435,7 +448,12 @@ private Flowable<Event> run(InvocationContext invocationContext, int stepsComple
return Flowable.empty();
} else {
logger.debug("Continuing to next step of the flow.");
return Flowable.defer(() -> run(invocationContext, stepsCompleted + 1));
return Flowable.defer(
() -> {
try (Scope scope = parentContext.makeCurrent()) {
return run(invocationContext, stepsCompleted + 1);
}
});
}
}));
}
Expand Down
37 changes: 24 additions & 13 deletions core/src/main/java/com/google/adk/flows/llmflows/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ private static Function<FunctionCall, Maybe<Event>> getFunctionCallMapper(
Map<String, BaseTool> tools,
Map<String, ToolConfirmation> toolConfirmations,
boolean isLive) {
Context parentContext = Context.current();
return functionCall -> {
BaseTool tool = tools.get(functionCall.name().get());
ToolContext toolContext =
Expand All @@ -259,14 +260,19 @@ private static Function<FunctionCall, Maybe<Event>> getFunctionCallMapper(
maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext)
.switchIfEmpty(
Maybe.defer(
() ->
isLive
() -> {
try (Scope scope = parentContext.makeCurrent()) {
return isLive
? processFunctionLive(
invocationContext, tool, toolContext, functionCall, functionArgs)
: callTool(tool, functionArgs, toolContext)));
: callTool(tool, functionArgs, toolContext);
}
}));

return postProcessFunctionResult(
maybeFunctionResult, invocationContext, tool, functionArgs, toolContext, isLive);
try (Scope scope = parentContext.makeCurrent()) {
return postProcessFunctionResult(
maybeFunctionResult, invocationContext, tool, functionArgs, toolContext, isLive);
}
};
}

Expand Down Expand Up @@ -372,6 +378,7 @@ private static Maybe<Event> postProcessFunctionResult(
Map<String, Object> functionArgs,
ToolContext toolContext,
boolean isLive) {
Context parentContext = Context.current();
return maybeFunctionResult
.map(Optional::of)
.defaultIfEmpty(Optional.empty())
Expand All @@ -393,14 +400,17 @@ private static Maybe<Event> postProcessFunctionResult(
.defaultIfEmpty(Optional.ofNullable(initialFunctionResult))
.flatMapMaybe(
finalOptionalResult -> {
Map<String, Object> finalFunctionResult = finalOptionalResult.orElse(null);
if (tool.longRunning() && finalFunctionResult == null) {
return Maybe.empty();
try (Scope scope = parentContext.makeCurrent()) {
Map<String, Object> finalFunctionResult =
finalOptionalResult.orElse(null);
if (tool.longRunning() && finalFunctionResult == null) {
return Maybe.empty();
}
Event functionResponseEvent =
buildResponseEvent(
tool, finalFunctionResult, toolContext, invocationContext);
return Maybe.just(functionResponseEvent);
}
Event functionResponseEvent =
buildResponseEvent(
tool, finalFunctionResult, toolContext, invocationContext);
return Maybe.just(functionResponseEvent);
});
});
}
Expand Down Expand Up @@ -552,12 +562,13 @@ private static Maybe<Map<String, Object>> maybeInvokeAfterToolCall(
private static Maybe<Map<String, Object>> callTool(
BaseTool tool, Map<String, Object> args, ToolContext toolContext) {
Tracer tracer = Tracing.getTracer();
Context parentContext = Context.current();
return Maybe.defer(
() -> {
Span span =
tracer
.spanBuilder("tool_call [" + tool.name() + "]")
.setParent(Context.current())
.setParent(parentContext)
.startSpan();
try (Scope scope = span.makeCurrent()) {
Tracing.traceToolCall(args);
Expand Down