Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,12 +828,16 @@ async def stream_async(
events = self._run_loop(messages, merged_state, structured_output_model, structured_output_prompt)

async for event in events:
# Snapshot the event data before prepare() merges invocation_state
# into the dict. The callback_handler receives the full merged dict
# for backward compatibility, but stream_async() callers only see
# the serializable event fields.
event_data = event.as_dict()
event.prepare(invocation_state=merged_state)

if event.is_callback_event:
as_dict = event.as_dict()
callback_handler(**as_dict)
yield as_dict
callback_handler(**event.as_dict())
yield event_data

result = AgentResult(*event["stop"])
callback_handler(result=result)
Expand Down
32 changes: 28 additions & 4 deletions tests/strands/agent/hooks/test_agent_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,24 @@ def mock_sleep():
"request_state": {},
}

# Keys that prepare() merges from invocation_state. stream_async() no longer includes
# these in yielded events; callback_handler still receives them for backward compat.
_INVOCATION_STATE_KEYS = frozenset(any_props.keys()) | frozenset(
{
"event_loop_parent_cycle_id",
"messages",
"model",
"system_prompt",
"tool_config",
}
)


def _strip_state(events: list[dict], *user_keys: str) -> list[dict]:
"""Return events with invocation_state fields removed (matches what stream_async() yields)."""
keys_to_remove = _INVOCATION_STATE_KEYS | set(user_keys)
return [{k: v for k, v in e.items() if k not in keys_to_remove} for e in events]


@pytest.mark.asyncio
async def test_stream_e2e_success(alist):
Expand Down Expand Up @@ -317,7 +335,10 @@ async def test_stream_e2e_success(alist):
),
},
]
assert tru_events == exp_events
# stream_async() yields events without invocation_state; callback_handler receives
# the full merged dict. Verify both independently.
exp_yield_events = _strip_state(exp_events, "arg1")
assert tru_events == exp_yield_events

exp_calls = [call(**event) for event in exp_events]
act_calls = mock_callback.call_args_list
Expand Down Expand Up @@ -381,7 +402,8 @@ async def test_stream_e2e_throttle_and_redact(alist, mock_sleep):
),
},
]
assert tru_events == exp_events
exp_yield_events = _strip_state(exp_events, "arg1")
assert tru_events == exp_yield_events

exp_calls = [call(**event) for event in exp_events]
act_calls = mock_callback.call_args_list
Expand Down Expand Up @@ -459,7 +481,8 @@ async def test_stream_e2e_reasoning_redacted_content(alist):
),
},
]
assert tru_events == exp_events
exp_yield_events = _strip_state(exp_events)
assert tru_events == exp_yield_events

exp_calls = [call(**event) for event in exp_events]
act_calls = mock_callback.call_args_list
Expand Down Expand Up @@ -514,7 +537,8 @@ async def test_event_loop_cycle_text_response_throttling_early_end(
{"force_stop": True, "force_stop_reason": "ThrottlingException | ConverseStream"},
]

assert tru_events == exp_events
exp_yield_events = _strip_state(exp_events, "arg1")
assert tru_events == exp_yield_events

exp_calls = [call(**event) for event in exp_events]
act_calls = mock_callback.call_args_list
Expand Down
56 changes: 52 additions & 4 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,8 +1080,11 @@ async def test_event_loop(*args, **kwargs):
stream = agent.stream_async("test message", callback_handler=mock_callback)

tru_events = await alist(stream)

# stream_async() yields events without invocation_state merged in; invocation_state
# is only passed to the callback_handler for backward compat.
exp_events = [
{"init_event_loop": True, "callback_handler": mock_callback},
{"init_event_loop": True},
{"data": "First chunk"},
{"data": "Second chunk"},
{"complete": True, "data": "Final chunk"},
Expand All @@ -1096,8 +1099,24 @@ async def test_event_loop(*args, **kwargs):
]
assert tru_events == exp_events

exp_calls = [unittest.mock.call(**event) for event in exp_events]
mock_callback.assert_has_calls(exp_calls)
# The callback_handler receives the fully-merged dict (including invocation_state).
exp_callback_calls = [
unittest.mock.call(**{"init_event_loop": True, "callback_handler": mock_callback}),
unittest.mock.call(**{"data": "First chunk"}),
unittest.mock.call(**{"data": "Second chunk"}),
unittest.mock.call(**{"complete": True, "data": "Final chunk"}),
unittest.mock.call(
**{
"result": AgentResult(
stop_reason="stop",
message={"role": "assistant", "content": [{"text": "Response"}]},
metrics={},
state={},
)
}
),
]
mock_callback.assert_has_calls(exp_callback_calls)


@pytest.mark.asyncio
Expand Down Expand Up @@ -1196,7 +1215,7 @@ async def check_invocation_state(**kwargs):

tru_events = await alist(stream)
exp_events = [
{"init_event_loop": True, "some_value": "a_value"},
{"init_event_loop": True},
{
"result": AgentResult(
stop_reason="stop",
Expand All @@ -1211,6 +1230,35 @@ async def check_invocation_state(**kwargs):
assert mock_event_loop_cycle.call_count == 1


@pytest.mark.asyncio
async def test_stream_async_does_not_yield_invocation_state(mock_event_loop_cycle, alist):
"""stream_async() must not include invocation_state in yielded events.

Non-serializable objects passed via invocation_state were previously merged
into every ModelStreamEvent by prepare(), causing repr() serialization of
~131 KB Agent/Span objects on the wire (issue #1928).
"""

class _NotSerializable:
pass

not_serializable = _NotSerializable()

async def test_event_loop(*args, **kwargs):
yield ModelStreamEvent({"data": "hello", "delta": {"text": "hello"}})
yield EventLoopStopEvent("end_turn", {"role": "assistant", "content": []}, {}, {})

mock_event_loop_cycle.side_effect = test_event_loop

agent = Agent()
events = await alist(agent.stream_async("hi", invocation_state={"obj": not_serializable}))

stream_events = [e for e in events if "data" in e]
assert len(stream_events) == 1
assert "obj" not in stream_events[0], "invocation_state must not appear in yielded stream events"
assert stream_events[0] == {"data": "hello", "delta": {"text": "hello"}}


@pytest.mark.asyncio
async def test_stream_async_raises_exceptions(mock_event_loop_cycle):
mock_event_loop_cycle.side_effect = ValueError("Test exception")
Expand Down