Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@
import com.google.adk.events.EventCompaction;
import com.google.adk.sessions.BaseSessionService;
import com.google.adk.sessions.Session;
import com.google.common.collect.Lists;
import com.google.genai.types.GenerateContentResponseUsageMetadata;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Maybe;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.ListIterator;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -38,6 +41,7 @@
* <li>Keeps the {@code retentionSize} most recent events raw.
* <li>Compacts all events that have never been compacted and are older than the retained tail,
* including the most recent compaction event, into a new summary event.
* <li>Triggers compaction only if the prompt token count exceeds the {@code tokenThreshold}.
* <li>The new summary event is generated by the {@link BaseEventSummarizer}.
* <li>Appends this new summary event to the end of the event stream.
* </ul>
Expand All @@ -52,21 +56,52 @@ public final class TailRetentionEventCompactor implements EventCompactor {

private final BaseEventSummarizer summarizer;
private final int retentionSize;
private final int tokenThreshold;

/**
 * Creates a compactor that keeps the most recent events raw and only triggers compaction once
 * the prompt token count exceeds a threshold.
 *
 * @param summarizer summarizer used to produce the summary event; it is not validated here, its
 *     presence is checked when {@code compact} is called
 * @param retentionSize number of most recent events to keep raw; must be non-negative
 * @param tokenThreshold prompt token count that must be exceeded for compaction to run; must be
 *     non-negative
 * @throws IllegalArgumentException if {@code retentionSize} or {@code tokenThreshold} is negative
 */
public TailRetentionEventCompactor(
    BaseEventSummarizer summarizer, int retentionSize, int tokenThreshold) {
  checkArgument(tokenThreshold >= 0, "tokenThreshold must be non-negative");
  checkArgument(retentionSize >= 0, "retentionSize must be non-negative");
  this.summarizer = summarizer;
  this.retentionSize = retentionSize;
  this.tokenThreshold = tokenThreshold;
}

/**
 * Runs tail-retention compaction for the given session.
 *
 * <p>The session's events are first checked against the token threshold; when the check fails
 * the returned {@link Completable} completes without summarizing or appending anything.
 * Otherwise the events selected by {@code getCompactionEvents} are passed to the summarizer and
 * the resulting summary event, if any, is appended to the session through {@code
 * sessionService}.
 *
 * @param session session whose events are inspected and to which the summary is appended
 * @param sessionService service used to append the summary event
 * @return a {@link Completable} that completes once the pipeline has finished
 * @throws IllegalArgumentException if this compactor was created without a summarizer
 */
@Override
public Completable compact(Session session, BaseSessionService sessionService) {
  checkArgument(summarizer != null, "Missing BaseEventSummarizer for event compaction");
  logger.debug("Running tail retention event compaction for session {}", session.id());

  // Single pipeline: the threshold filter must run before any events are selected or summarized,
  // so an under-threshold session short-circuits to an empty Maybe.
  return Maybe.just(session.events())
      .filter(this::shouldCompact)
      .flatMap(events -> getCompactionEvents(events))
      .flatMap(summarizer::summarizeEvents)
      .flatMapSingle(e -> sessionService.appendEvent(session, e))
      .ignoreElement();
}

/**
 * Tells whether the latest reported prompt token count is strictly above {@code tokenThreshold}.
 *
 * <p>When no event reports a prompt token count, the count is treated as 0.
 *
 * @param events events to inspect for usage metadata
 * @return {@code true} if compaction should proceed, {@code false} if it should be skipped
 */
private boolean shouldCompact(List<Event> events) {
  // TODO b/480013930 - Add a way to estimate the prompt token if the usage metadata is not
  // available.
  int latestPromptTokens = getLatestPromptTokenCount(events).orElse(0);
  boolean exceedsThreshold = latestPromptTokens > tokenThreshold;

  if (!exceedsThreshold) {
    logger.debug(
        "Skipping compaction. Prompt token count {} is within threshold {}",
        latestPromptTokens,
        tokenThreshold);
  }
  return exceedsThreshold;
}

/**
 * Finds the prompt token count reported by the last event in the list that carries one.
 *
 * @param events events to scan, from the end of the list towards the start
 * @return the prompt token count of the last event that has both usage metadata and a prompt
 *     token count, or an empty {@link Optional} if no event has one
 */
private Optional<Integer> getLatestPromptTokenCount(List<Event> events) {
  // Walk backwards so the first hit is the most recently listed event with a token count.
  ListIterator<Event> cursor = events.listIterator(events.size());
  while (cursor.hasPrevious()) {
    Optional<Integer> promptTokens =
        cursor
            .previous()
            .usageMetadata()
            .flatMap(GenerateContentResponseUsageMetadata::promptTokenCount);
    if (promptTokens.isPresent()) {
      return promptTokens;
    }
  }
  return Optional.empty();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.adk.summarizer;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.never;
Expand All @@ -30,6 +31,7 @@
import com.google.adk.sessions.Session;
import com.google.common.collect.ImmutableList;
import com.google.genai.types.Content;
import com.google.genai.types.GenerateContentResponseUsageMetadata;
import com.google.genai.types.Part;
import io.reactivex.rxjava3.core.Maybe;
import io.reactivex.rxjava3.core.Single;
Expand All @@ -52,15 +54,91 @@ public class TailRetentionEventCompactorTest {
@Mock private BaseEventSummarizer mockSummarizer;
@Captor private ArgumentCaptor<List<Event>> eventListCaptor;

/** A negative token threshold is rejected at construction time with a descriptive message. */
@Test
public void constructor_negativeTokenThreshold_throwsException() {
  IllegalArgumentException thrown =
      assertThrows(
          IllegalArgumentException.class,
          () -> new TailRetentionEventCompactor(mockSummarizer, 2, -1));

  assertThat(thrown).hasMessageThat().contains("tokenThreshold must be non-negative");
}

/** A negative retention size is rejected at construction time with a descriptive message. */
@Test
public void constructor_negativeRetentionSize_throwsException() {
  IllegalArgumentException thrown =
      assertThrows(
          IllegalArgumentException.class,
          () -> new TailRetentionEventCompactor(mockSummarizer, -1, 100));

  assertThat(thrown).hasMessageThat().contains("retentionSize must be non-negative");
}

/**
 * With no usage metadata on any event, compaction is skipped: nothing is summarized and nothing
 * is appended.
 */
// TODO: b/480013930 - Add a test case for estimating the prompt token if the usage metadata is
// not available.
@Test
public void compaction_skippedWhenTokenUsageMissing() {
  EventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 100);
  ImmutableList<Event> events =
      ImmutableList.of(
          createEvent(1, "Event1"),
          createEvent(2, "Retain1"),
          createEvent(3, "Retain2")); // No usage metadata
  Session session = Session.builder("id").events(events).build();

  // Use a TestObserver instead of blockingSubscribe(): the latter forwards upstream errors to
  // the global RxJava error handler, so a failing pipeline would still satisfy the never()
  // verifications below.
  compactor.compact(session, mockSessionService).test().assertComplete();

  verify(mockSummarizer, never()).summarizeEvents(any());
  verify(mockSessionService, never()).appendEvent(any(), any());
}

/**
 * When the latest reported prompt token count is below the threshold, compaction is skipped:
 * nothing is summarized and nothing is appended.
 */
@Test
public void compaction_skippedWhenTokenUsageBelowThreshold() {
  // Threshold is 300, usage is 200.
  EventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 300);
  ImmutableList<Event> events =
      ImmutableList.of(
          createEvent(1, "Event1"),
          createEvent(2, "Retain1"),
          withUsage(createEvent(3, "Retain2"), 200));
  Session session = Session.builder("id").events(events).build();

  // Use a TestObserver instead of blockingSubscribe(): the latter forwards upstream errors to
  // the global RxJava error handler, so a failing pipeline would still satisfy the never()
  // verifications below.
  compactor.compact(session, mockSessionService).test().assertComplete();

  verify(mockSummarizer, never()).summarizeEvents(any());
  verify(mockSessionService, never()).appendEvent(any(), any());
}

/**
 * When the latest reported prompt token count is above the threshold, the summarizer is invoked
 * and its summary event is appended to the session.
 */
@Test
public void compaction_happensWhenTokenUsageAboveThreshold() {
  // Threshold is 300, usage is 400.
  EventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 300);
  Event event3 = withUsage(createEvent(3, "Retain2"), 400);
  ImmutableList<Event> events =
      ImmutableList.of(createEvent(1, "Event1"), createEvent(2, "Retain1"), event3);
  Session session = Session.builder("id").events(events).build();
  Event summaryEvent = createEvent(4, "Summary");

  when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(summaryEvent));
  when(mockSessionService.appendEvent(any(), any())).thenReturn(Single.just(summaryEvent));

  // Assert completion explicitly: blockingSubscribe() would forward an upstream error to the
  // global RxJava error handler instead of failing this test at the point of failure.
  compactor.compact(session, mockSessionService).test().assertComplete();

  verify(mockSummarizer).summarizeEvents(any());
  verify(mockSessionService).appendEvent(eq(session), eq(summaryEvent));
}

@Test
public void compact_notEnoughEvents_doesNothing() {
ImmutableList<Event> events =
ImmutableList.of(
createEvent(1, "Event1"), createEvent(2, "Event2"), createEvent(3, "Event3"));
createEvent(1, "Event1"),
createEvent(2, "Event2"),
withUsage(createEvent(3, "Event3"), 200));
Session session = Session.builder("id").events(events).build();

// Retention size 5 > 3 events
TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 5);
// Retention size 5 > 3 events. Token usage 200 > threshold 100.
TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 5, 100);

compactor.compact(session, mockSessionService).test().assertComplete();

Expand All @@ -73,14 +151,17 @@ public void compact_respectRetentionSize_summarizesCorrectEvents() {
// Retention size is 2.
ImmutableList<Event> events =
ImmutableList.of(
createEvent(1, "Event1"), createEvent(2, "Retain1"), createEvent(3, "Retain2"));
createEvent(1, "Event1"),
createEvent(2, "Retain1"),
withUsage(createEvent(3, "Retain2"), 200));
Session session = Session.builder("id").events(events).build();
Event compactedEvent = createCompactedEvent(1, 1, "Summary", 4);

when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(compactedEvent));
when(mockSessionService.appendEvent(any(), any())).then(i -> Single.just(i.getArgument(1)));

TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2);
// Token usage 200 > threshold 100.
TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 100);

compactor.compact(session, mockSessionService).test().assertComplete();

Expand Down Expand Up @@ -121,14 +202,15 @@ public void compact_withRetainedEventsPhysicallyBeforeCompaction_includesThem()
createCompactedEvent(
/* startTimestamp= */ 1, /* endTimestamp= */ 2, "C1", /* eventTimestamp= */ 4),
createEvent(5, "E5"),
createEvent(6, "E6"));
withUsage(createEvent(6, "E6"), 200));
Session session = Session.builder("id").events(events).build();
Event compactedEvent = createCompactedEvent(1, 5, "Summary C1-E5", 7);

when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(compactedEvent));
when(mockSessionService.appendEvent(any(), any())).then(i -> Single.just(i.getArgument(1)));

TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 1);
// Token usage 200 > threshold 100.
TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 1, 100);

compactor.compact(session, mockSessionService).test().assertComplete();

Expand Down Expand Up @@ -181,14 +263,15 @@ public void compact_withMultipleCompactionEvents_respectsCompactionBoundary() {
createEvent(7, "E7"),
createCompactedEvent(
/* startTimestamp= */ 1, /* endTimestamp= */ 3, "C2", /* eventTimestamp= */ 8),
createEvent(9, "E9"));
withUsage(createEvent(9, "E9"), 200));
Session session = Session.builder("id").events(events).build();
Event compactedEvent = createCompactedEvent(1, 4, "Summary C2-E4", 10);

when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(compactedEvent));
when(mockSessionService.appendEvent(any(), any())).then(i -> Single.just(i.getArgument(1)));

TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 3);
// Token usage 200 > threshold 100.
TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 3, 100);

compactor.compact(session, mockSessionService).test().assertComplete();

Expand Down Expand Up @@ -224,6 +307,13 @@ private static String getPromptText(Event event) {
.orElseThrow();
}

/**
 * Rebuilds {@code event} through its builder with usage metadata whose prompt token count is
 * {@code tokens}.
 */
private Event withUsage(Event event, int tokens) {
  GenerateContentResponseUsageMetadata usage =
      GenerateContentResponseUsageMetadata.builder().promptTokenCount(tokens).build();
  return event.toBuilder().usageMetadata(usage).build();
}

private Event createCompactedEvent(
long startTimestamp, long endTimestamp, String content, long eventTimestamp) {
return Event.builder()
Expand Down