diff --git a/core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java b/core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java index c13a49cc1..b084de860 100644 --- a/core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java +++ b/core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java @@ -22,12 +22,15 @@ import com.google.adk.events.EventCompaction; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; +import com.google.common.collect.Lists; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.ListIterator; +import java.util.Optional; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -38,6 +41,7 @@ *
  • Keeps the {@code retentionSize} most recent events raw. *
  • Compacts all events that never compacted and older than the retained tail, including the * most recent compaction event, into a new summary event. + *
  • Triggers compaction only if the latest prompt token count exceeds {@code tokenThreshold}. *
  • The new summary event is generated by the {@link BaseEventSummarizer}. *
  • Appends this new summary event to the end of the event stream. * @@ -52,10 +56,15 @@ public final class TailRetentionEventCompactor implements EventCompactor { private final BaseEventSummarizer summarizer; private final int retentionSize; + private final int tokenThreshold; - public TailRetentionEventCompactor(BaseEventSummarizer summarizer, int retentionSize) { + public TailRetentionEventCompactor( + BaseEventSummarizer summarizer, int retentionSize, int tokenThreshold) { + checkArgument(tokenThreshold >= 0, "tokenThreshold must be non-negative"); + checkArgument(retentionSize >= 0, "retentionSize must be non-negative"); this.summarizer = summarizer; this.retentionSize = retentionSize; + this.tokenThreshold = tokenThreshold; } @Override @@ -63,10 +72,36 @@ public Completable compact(Session session, BaseSessionService sessionService) { checkArgument(summarizer != null, "Missing BaseEventSummarizer for event compaction"); logger.debug("Running tail retention event compaction for session {}", session.id()); - return Completable.fromMaybe( - getCompactionEvents(session.events()) - .flatMap(summarizer::summarizeEvents) - .flatMapSingle(e -> sessionService.appendEvent(session, e))); + return Maybe.just(session.events()) + .filter(this::shouldCompact) + .flatMap(events -> getCompactionEvents(events)) + .flatMap(summarizer::summarizeEvents) + .flatMapSingle(e -> sessionService.appendEvent(session, e)) + .ignoreElement(); + } + + private boolean shouldCompact(List events) { + int count = getLatestPromptTokenCount(events).orElse(0); + + // TODO b/480013930 - Add a way to estimate the prompt token if the usage metadata is not + // available. + if (count <= tokenThreshold) { + logger.debug( + "Skipping compaction. 
Prompt token count {} is within threshold {}", + count, + tokenThreshold); + return false; + } + return true; + } + + private Optional getLatestPromptTokenCount(List events) { + return Lists.reverse(events).stream() + .map(Event::usageMetadata) + .flatMap(Optional::stream) + .map(GenerateContentResponseUsageMetadata::promptTokenCount) + .flatMap(Optional::stream) + .findFirst(); } /** diff --git a/core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java b/core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java index b4a6c3474..3260fbe1e 100644 --- a/core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java +++ b/core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java @@ -17,6 +17,7 @@ package com.google.adk.summarizer; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.never; @@ -30,6 +31,7 @@ import com.google.adk.sessions.Session; import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; @@ -52,15 +54,91 @@ public class TailRetentionEventCompactorTest { @Mock private BaseEventSummarizer mockSummarizer; @Captor private ArgumentCaptor> eventListCaptor; + @Test + public void constructor_negativeTokenThreshold_throwsException() { + assertThat( + assertThrows( + IllegalArgumentException.class, + () -> new TailRetentionEventCompactor(mockSummarizer, 2, -1))) + .hasMessageThat() + .contains("tokenThreshold must be non-negative"); + } + + @Test + public void constructor_negativeRetentionSize_throwsException() { + assertThat( + assertThrows( + IllegalArgumentException.class, + () -> 
new TailRetentionEventCompactor(mockSummarizer, -1, 100))) + .hasMessageThat() + .contains("retentionSize must be non-negative"); + } + + @Test + // TODO: b/480013930 - Add a test case for estimating the prompt token if the usage metadata is + // not available. + public void compaction_skippedWhenTokenUsageMissing() { + EventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 100); + ImmutableList events = + ImmutableList.of( + createEvent(1, "Event1"), + createEvent(2, "Retain1"), + createEvent(3, "Retain2")); // No usage metadata + Session session = Session.builder("id").events(events).build(); + + compactor.compact(session, mockSessionService).blockingSubscribe(); + + verify(mockSummarizer, never()).summarizeEvents(any()); + verify(mockSessionService, never()).appendEvent(any(), any()); + } + + @Test + public void compaction_skippedWhenTokenUsageBelowThreshold() { + // Threshold is 300, usage is 200. + EventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 300); + ImmutableList events = + ImmutableList.of( + createEvent(1, "Event1"), + createEvent(2, "Retain1"), + withUsage(createEvent(3, "Retain2"), 200)); + Session session = Session.builder("id").events(events).build(); + + compactor.compact(session, mockSessionService).blockingSubscribe(); + + verify(mockSummarizer, never()).summarizeEvents(any()); + verify(mockSessionService, never()).appendEvent(any(), any()); + } + + @Test + public void compaction_happensWhenTokenUsageAboveThreshold() { + // Threshold is 300, usage is 400. 
+ EventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 300); + Event event3 = withUsage(createEvent(3, "Retain2"), 400); + ImmutableList events = + ImmutableList.of(createEvent(1, "Event1"), createEvent(2, "Retain1"), event3); + Session session = Session.builder("id").events(events).build(); + Event summaryEvent = createEvent(4, "Summary"); + + when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(summaryEvent)); + when(mockSessionService.appendEvent(any(), any())).thenReturn(Single.just(summaryEvent)); + + compactor.compact(session, mockSessionService).blockingSubscribe(); + + verify(mockSummarizer).summarizeEvents(any()); + verify(mockSessionService).appendEvent(eq(session), eq(summaryEvent)); + } + @Test public void compact_notEnoughEvents_doesNothing() { ImmutableList events = ImmutableList.of( - createEvent(1, "Event1"), createEvent(2, "Event2"), createEvent(3, "Event3")); + createEvent(1, "Event1"), + createEvent(2, "Event2"), + withUsage(createEvent(3, "Event3"), 200)); Session session = Session.builder("id").events(events).build(); - // Retention size 5 > 3 events - TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 5); + // Retention size 5 > 3 events. Token usage 200 > threshold 100. + TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 5, 100); compactor.compact(session, mockSessionService).test().assertComplete(); @@ -73,14 +151,17 @@ public void compact_respectRetentionSize_summarizesCorrectEvents() { // Retention size is 2. 
ImmutableList events = ImmutableList.of( - createEvent(1, "Event1"), createEvent(2, "Retain1"), createEvent(3, "Retain2")); + createEvent(1, "Event1"), + createEvent(2, "Retain1"), + withUsage(createEvent(3, "Retain2"), 200)); Session session = Session.builder("id").events(events).build(); Event compactedEvent = createCompactedEvent(1, 1, "Summary", 4); when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(compactedEvent)); when(mockSessionService.appendEvent(any(), any())).then(i -> Single.just(i.getArgument(1))); - TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2); + // Token usage 200 > threshold 100. + TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 100); compactor.compact(session, mockSessionService).test().assertComplete(); @@ -121,14 +202,15 @@ public void compact_withRetainedEventsPhysicallyBeforeCompaction_includesThem() createCompactedEvent( /* startTimestamp= */ 1, /* endTimestamp= */ 2, "C1", /* eventTimestamp= */ 4), createEvent(5, "E5"), - createEvent(6, "E6")); + withUsage(createEvent(6, "E6"), 200)); Session session = Session.builder("id").events(events).build(); Event compactedEvent = createCompactedEvent(1, 5, "Summary C1-E5", 7); when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(compactedEvent)); when(mockSessionService.appendEvent(any(), any())).then(i -> Single.just(i.getArgument(1))); - TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 1); + // Token usage 200 > threshold 100. 
+ TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 1, 100); compactor.compact(session, mockSessionService).test().assertComplete(); @@ -181,14 +263,15 @@ public void compact_withMultipleCompactionEvents_respectsCompactionBoundary() { createEvent(7, "E7"), createCompactedEvent( /* startTimestamp= */ 1, /* endTimestamp= */ 3, "C2", /* eventTimestamp= */ 8), - createEvent(9, "E9")); + withUsage(createEvent(9, "E9"), 200)); Session session = Session.builder("id").events(events).build(); Event compactedEvent = createCompactedEvent(1, 4, "Summary C2-E4", 10); when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(compactedEvent)); when(mockSessionService.appendEvent(any(), any())).then(i -> Single.just(i.getArgument(1))); - TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 3); + // Token usage 200 > threshold 100. + TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 3, 100); compactor.compact(session, mockSessionService).test().assertComplete(); @@ -224,6 +307,13 @@ private static String getPromptText(Event event) { .orElseThrow(); } + private Event withUsage(Event event, int tokens) { + return event.toBuilder() + .usageMetadata( + GenerateContentResponseUsageMetadata.builder().promptTokenCount(tokens).build()) + .build(); + } + private Event createCompactedEvent( long startTimestamp, long endTimestamp, String content, long eventTimestamp) { return Event.builder()