diff --git a/src/google/adk/agents/live_request_queue.py b/src/google/adk/agents/live_request_queue.py index 9b698c81d6..f419d136e1 100644 --- a/src/google/adk/agents/live_request_queue.py +++ b/src/google/adk/agents/live_request_queue.py @@ -24,34 +24,26 @@ class LiveRequest(BaseModel): - """Request send to live agents.""" + """Request sent to live agents. + + When multiple fields are set, they are processed by priority (highest first): + activity_start > activity_end > audio_stream_end > blob > content. + """ model_config = ConfigDict(ser_json_bytes='base64', val_json_bytes='base64') """The pydantic model config.""" content: Optional[types.Content] = None - """If set, send the content to the model in turn-by-turn mode. - - When multiple fields are set, they are processed by priority (highest first): - activity_start > activity_end > blob > content. - """ + """If set, send the content to the model in turn-by-turn mode.""" blob: Optional[types.Blob] = None - """If set, send the blob to the model in realtime mode. - - When multiple fields are set, they are processed by priority (highest first): - activity_start > activity_end > blob > content. - """ + """If set, send the blob to the model in realtime mode.""" activity_start: Optional[types.ActivityStart] = None - """If set, signal the start of user activity to the model. - - When multiple fields are set, they are processed by priority (highest first): - activity_start > activity_end > blob > content. - """ + """If set, signal the start of user activity to the model.""" activity_end: Optional[types.ActivityEnd] = None - """If set, signal the end of user activity to the model. - - When multiple fields are set, they are processed by priority (highest first): - activity_start > activity_end > blob > content. + """If set, signal the end of user activity to the model.""" + audio_stream_end: bool = False + """If set, signal the end of the audio stream to the model.
+ This is only used when Voice Activity Detection is enabled. """ close: bool = False """If set, close the queue. queue.shutdown() is only supported in Python 3.13+.""" @@ -80,6 +72,10 @@ def send_activity_end(self): """Sends an activity end signal to mark the end of user input.""" self._queue.put_nowait(LiveRequest(activity_end=types.ActivityEnd())) + def send_audio_stream_end(self): + """Sends an audio stream end signal to force flush audio.""" + self._queue.put_nowait(LiveRequest(audio_stream_end=True)) + def send(self, req: LiveRequest): self._queue.put_nowait(req) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index f1c1cce813..d3818e5a63 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -308,6 +308,10 @@ async def _send_to_model( await llm_connection.send_realtime(types.ActivityStart()) elif live_request.activity_end: await llm_connection.send_realtime(types.ActivityEnd()) + elif live_request.audio_stream_end: + await llm_connection.send_realtime( + types.LiveClientRealtimeInput(audio_stream_end=True) + ) elif live_request.blob: # Cache input audio chunks before flushing self.audio_cache_manager.cache_audio( diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index 63606b21b0..215b409e6a 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -28,7 +28,12 @@ logger = logging.getLogger('google_adk.' 
+ __name__) -RealtimeInput = Union[types.Blob, types.ActivityStart, types.ActivityEnd] +RealtimeInput = Union[ + types.Blob, + types.ActivityStart, + types.ActivityEnd, + types.LiveClientRealtimeInput, +] from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -136,6 +141,13 @@ async def send_realtime(self, input: RealtimeInput): elif isinstance(input, types.ActivityEnd): logger.debug('Sending LLM activity end signal.') await self._gemini_session.send_realtime_input(activity_end=input) + + elif isinstance(input, types.LiveClientRealtimeInput): + if input.audio_stream_end: + logger.debug('Sending LLM audio stream end signal.') + await self._gemini_session.send_realtime_input(audio_stream_end=True) + else: + logger.warning('Unary LiveClientRealtimeInput not fully supported yet.') else: raise ValueError('Unsupported input type: %s' % type(input)) diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py index d065661c69..daba14f07e 100644 --- a/tests/unittests/models/test_gemini_llm_connection.py +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -69,6 +69,22 @@ async def test_send_realtime_default_behavior( mock_gemini_session.send.assert_not_called() +@pytest.mark.asyncio +async def test_send_realtime_audiostreamend( + gemini_connection, mock_gemini_session +): + """Test send_realtime with LiveClientRealtimeInput(audio_stream_end=True).""" + input_signal = types.LiveClientRealtimeInput(audio_stream_end=True) + await gemini_connection.send_realtime(input_signal) + + # Should call send_realtime_input with audio_stream_end=True + mock_gemini_session.send_realtime_input.assert_called_once_with( + audio_stream_end=True + ) + # Should not call .send function + mock_gemini_session.send.assert_not_called() + + @pytest.mark.asyncio async def test_send_history(gemini_connection, mock_gemini_session): """Test send_history method."""