Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,41 +64,15 @@ public Single<RequestProcessor.RequestProcessingResult> processRequest(
return Single.just(RequestProcessingResult.create(llmRequest, ImmutableList.of()));
}

int confirmationEventIndex = -1;
ImmutableMap<String, ToolConfirmation> responses = ImmutableMap.of();
// Search backwards for the most recent user event that contains request confirmation
// function responses.
for (int i = events.size() - 1; i >= 0; i--) {
Event event = events.get(i);
if (!Objects.equals(event.author(), "user") || event.functionResponses().isEmpty()) {
continue;
}

ImmutableMap<String, ToolConfirmation> confirmationsInEvent =
event.functionResponses().stream()
.filter(functionResponse -> functionResponse.id().isPresent())
.filter(
functionResponse ->
Objects.equals(
functionResponse.name().orElse(null),
REQUEST_CONFIRMATION_FUNCTION_CALL_NAME))
.map(this::maybeCreateToolConfirmationEntry)
.flatMap(Optional::stream)
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
if (!confirmationsInEvent.isEmpty()) {
responses = confirmationsInEvent;
confirmationEventIndex = i;
break;
}
}
if (responses.isEmpty()) {
Optional<ConfirmationResult> confirmationResult = findMostRecentConfirmations(events);
if (confirmationResult.isEmpty()) {
logger.trace("No request confirmation function responses found.");
return Single.just(RequestProcessingResult.create(llmRequest, ImmutableList.of()));
}

// Make them final to enable access from lambda expressions.
final int finalConfirmationEventIndex = confirmationEventIndex;
final ImmutableMap<String, ToolConfirmation> requestConfirmationFunctionResponses = responses;
int finalConfirmationEventIndex = confirmationResult.get().eventIndex();
ImmutableMap<String, ToolConfirmation> requestConfirmationFunctionResponses =
confirmationResult.get().responses();

// Search backwards from the event before confirmation for the corresponding
// request_confirmation function calls emitted by the model.
Expand Down Expand Up @@ -169,6 +143,34 @@ public Single<RequestProcessor.RequestProcessingResult> processRequest(
return Single.just(RequestProcessingResult.create(llmRequest, ImmutableList.of()));
}

private static Optional<ConfirmationResult> findMostRecentConfirmations(
ImmutableList<Event> events) {
// Search backwards for the most recent user event that contains request confirmation
// function responses.
for (int i = events.size() - 1; i >= 0; i--) {
Event event = events.get(i);
if (!Objects.equals(event.author(), "user") || event.functionResponses().isEmpty()) {
continue;
}

ImmutableMap<String, ToolConfirmation> confirmationsInEvent =
event.functionResponses().stream()
.filter(functionResponse -> functionResponse.id().isPresent())
.filter(
functionResponse ->
Objects.equals(
functionResponse.name().orElse(null),
REQUEST_CONFIRMATION_FUNCTION_CALL_NAME))
.map(RequestConfirmationLlmRequestProcessor::maybeCreateToolConfirmationEntry)
.flatMap(Optional::stream)
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
if (!confirmationsInEvent.isEmpty()) {
return Optional.of(new ConfirmationResult(confirmationsInEvent, i));
}
}
return Optional.empty();
}

private Optional<FunctionCall> getOriginalFunctionCall(FunctionCall functionCall) {
if (!functionCall.args().orElse(ImmutableMap.of()).containsKey(ORIGINAL_FUNCTION_CALL)) {
return Optional.empty();
Expand Down Expand Up @@ -220,7 +222,7 @@ private Maybe<Event> assembleEvent(
invocationContext, functionCallEvent, toolsMap, toolConfirmations));
}

private Optional<Map.Entry<String, ToolConfirmation>> maybeCreateToolConfirmationEntry(
private static Optional<Map.Entry<String, ToolConfirmation>> maybeCreateToolConfirmationEntry(
FunctionResponse functionResponse) {
Map<String, Object> responseMap = functionResponse.response().orElse(ImmutableMap.of());
if (responseMap.size() != 1 || !responseMap.containsKey("response")) {
Expand All @@ -242,4 +244,7 @@ private Optional<Map.Entry<String, ToolConfirmation>> maybeCreateToolConfirmatio

return Optional.empty();
}

private record ConfirmationResult(
ImmutableMap<String, ToolConfirmation> responses, int eventIndex) {}
}