diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 64343302fe..508c00d73d 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -926,6 +926,7 @@ async def run_live( session_id: Optional[str] = None, live_request_queue: LiveRequestQueue, run_config: Optional[RunConfig] = None, + state_delta: Optional[dict[str, Any]] = None, session: Optional[Session] = None, ) -> AsyncGenerator[Event, None]: """Runs the agent in live mode (experimental feature). @@ -966,6 +967,7 @@ async def run_live( None. live_request_queue: The queue for live requests. run_config: The run config for the agent. + state_delta: Optional state changes to apply to the session. session: The session to use. This parameter is deprecated, please use `user_id` and `session_id` instead. @@ -1009,6 +1011,16 @@ async def run_live( run_config=run_config, ) + # Apply state_delta if provided + if state_delta: + state_event = Event( + invocation_id=invocation_context.invocation_id, + author='user', + actions=EventActions(state_delta=state_delta), + ) + _apply_run_config_custom_metadata(state_event, run_config) + await self.session_service.append_event(session=session, event=state_event) + root_agent = self.agent invocation_context.agent = self._find_agent_to_run(session, root_agent) diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index c876bff53a..a5f5833d74 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -442,6 +442,50 @@ async def _run_live_impl( assert "non_streaming_tool" not in active_tools +@pytest.mark.asyncio +async def test_run_live_state_delta_applied_to_session(): + """run_live should apply state_delta to the session at the start.""" + session_service = InMemorySessionService() + artifact_service = InMemoryArtifactService() + runner = Runner( + app_name="run_live_app", + agent=MockLiveAgent("live_agent"), + session_service=session_service, + artifact_service=artifact_service, + auto_create_session=True, + ) + + live_queue = LiveRequestQueue() + state_delta = {"useCase": "voice_assistant"} + + agen = runner.run_live( + user_id="user", + session_id="test_session", + live_request_queue=live_queue, + state_delta=state_delta, + ) + + event = await agen.__anext__() + await agen.aclose() + + assert event.author == "live_agent" + + # Verify state_delta was applied to the session + session = await session_service.get_session( + app_name="run_live_app", user_id="user", session_id="test_session" + ) + assert session is not None + assert session.state.get("useCase") == "voice_assistant" + + # Verify the state_delta event was appended to the session + state_delta_events = [ + e for e in session.events + if e.actions.state_delta and e.author == "user" + ] + assert len(state_delta_events) == 1 + assert state_delta_events[0].actions.state_delta == state_delta + + @pytest.mark.asyncio async def test_runner_allows_nested_agent_directories(tmp_path, monkeypatch): project_root = tmp_path / "workspace"