diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_run_common.py b/python/packages/ag-ui/agent_framework_ag_ui/_run_common.py index cde338cbc7..0a9f4cea9c 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_run_common.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_run_common.py @@ -12,6 +12,12 @@ from ag_ui.core import ( BaseEvent, CustomEvent, + ReasoningEncryptedValueEvent, + ReasoningEndEvent, + ReasoningMessageContentEvent, + ReasoningMessageEndEvent, + ReasoningMessageStartEvent, + ReasoningStartEvent, RunFinishedEvent, StateSnapshotEvent, TextMessageContentEvent, @@ -224,27 +230,28 @@ def _emit_tool_call( return events -def _emit_tool_result( - content: Content, +def _emit_tool_result_common( + call_id: str, + raw_result: Any, flow: FlowState, predictive_handler: PredictiveStateHandler | None = None, ) -> list[BaseEvent]: - """Emit ToolCallResult events for function_result content.""" - events: list[BaseEvent] = [] + """Shared helper for emitting ToolCallEnd + ToolCallResult events and performing FlowState cleanup. - if not content.call_id: - return events + Both ``_emit_tool_result`` (standard function results) and ``_emit_mcp_tool_result`` + (MCP server tool results) delegate to this function. + """ + events: list[BaseEvent] = [] - events.append(ToolCallEndEvent(tool_call_id=content.call_id)) - flow.tool_calls_ended.add(content.call_id) + events.append(ToolCallEndEvent(tool_call_id=call_id)) + flow.tool_calls_ended.add(call_id) - raw_result = content.result if content.result is not None else "" result_content = raw_result if isinstance(raw_result, str) else json.dumps(make_json_safe(raw_result)) message_id = generate_event_id() events.append( ToolCallResultEvent( message_id=message_id, - tool_call_id=content.call_id, + tool_call_id=call_id, content=result_content, role="tool", ) @@ -254,7 +261,7 @@ def _emit_tool_result( { "id": message_id, "role": "tool", - "toolCallId": content.call_id, + "toolCallId": call_id, "content": result_content, } ) @@ -268,7 +275,7 @@ def _emit_tool_result( flow.tool_call_name = None if flow.message_id: - logger.debug("Closing text message (issue #3568 fix): message_id=%s", flow.message_id) + logger.debug("Closing text message: message_id=%s", flow.message_id) events.append(TextMessageEndEvent(message_id=flow.message_id)) flow.message_id = None flow.accumulated_text = "" @@ -276,6 +283,18 @@ def _emit_tool_result( return events +def _emit_tool_result( + content: Content, + flow: FlowState, + predictive_handler: PredictiveStateHandler | None = None, +) -> list[BaseEvent]: + """Emit ToolCallResult events for function_result content.""" + if not content.call_id: + return [] + raw_result = content.result if content.result is not None else "" + return _emit_tool_result_common(content.call_id, raw_result, flow, predictive_handler) + + def _emit_approval_request( content: Content, flow: FlowState, @@ -381,6 +400,107 @@ def _emit_oauth_consent(content: Content) -> list[BaseEvent]: ) +def _emit_mcp_tool_call(content: Content, flow: FlowState) -> list[BaseEvent]: + """Emit ToolCall start/args events for MCP server tool call content. + + MCP tool calls arrive as complete items (not streamed deltas), so we emit a + ``ToolCallStartEvent`` (and, when arguments are present, a ``ToolCallArgsEvent``) + immediately. This maps MCP-specific fields (tool_name, server_name) to the + same AG-UI ToolCall* events used by regular function calls, making MCP tool + execution visible to AG-UI consumers. Completion/end events are handled + separately by ``_emit_mcp_tool_result``. + """ + events: list[BaseEvent] = [] + + tool_call_id = content.call_id or generate_event_id() + tool_name = content.tool_name or "mcp_tool" + + display_name = tool_name + + events.append( + ToolCallStartEvent( + tool_call_id=tool_call_id, + tool_call_name=display_name, + parent_message_id=flow.message_id, + ) + ) + + # Serialize arguments + args_str = "" + if content.arguments: + args_str = ( + content.arguments if isinstance(content.arguments, str) else json.dumps(make_json_safe(content.arguments)) + ) + events.append(ToolCallArgsEvent(tool_call_id=tool_call_id, delta=args_str)) + + # Track in flow state for MESSAGES_SNAPSHOT + tool_entry = { + "id": tool_call_id, + "type": "function", + "function": {"name": display_name, "arguments": args_str}, + } + flow.pending_tool_calls.append(tool_entry) + flow.tool_calls_by_id[tool_call_id] = tool_entry + + return events + + +def _emit_mcp_tool_result( + content: Content, flow: FlowState, predictive_handler: PredictiveStateHandler | None = None +) -> list[BaseEvent]: + """Emit ToolCallResult events for MCP server tool result content. + + Delegates to the shared _emit_tool_result_common helper using content.output + (the MCP-specific result field) instead of content.result. + """ + if not content.call_id: + logger.warning("MCP tool result content missing call_id, skipping") + return [] + raw_output = content.output if content.output is not None else "" + return _emit_tool_result_common(content.call_id, raw_output, flow, predictive_handler) + + +def _emit_text_reasoning(content: Content) -> list[BaseEvent]: + """Emit AG-UI reasoning events for text_reasoning content. + + Uses the protocol-defined reasoning event types so that AG-UI consumers + such as CopilotKit can render reasoning natively. + + Only ``content.text`` is used for the visible reasoning message. If + ``content.protected_data`` is present it is emitted as a + ``ReasoningEncryptedValueEvent`` so that consumers can persist encrypted + reasoning for state continuity without conflating it with display text. + """ + text = content.text or "" + if not text and content.protected_data is None: + return [] + + message_id = content.id or generate_event_id() + + events: list[BaseEvent] = [ + ReasoningStartEvent(message_id=message_id), + ReasoningMessageStartEvent(message_id=message_id, role="assistant"), + ] + + if text: + events.append(ReasoningMessageContentEvent(message_id=message_id, delta=text)) + + events.append(ReasoningMessageEndEvent(message_id=message_id)) + + if content.protected_data is not None: + events.append( + ReasoningEncryptedValueEvent( + subtype="message", + entity_id=message_id, + encrypted_value=content.protected_data, + ) + ) + + events.append(ReasoningEndEvent(message_id=message_id)) + + return events + + def _emit_content( content: Any, flow: FlowState, @@ -402,5 +522,11 @@ def _emit_content( return _emit_usage(content) if content_type == "oauth_consent_request": return _emit_oauth_consent(content) + if content_type == "mcp_server_tool_call": + return _emit_mcp_tool_call(content, flow) + if content_type == "mcp_server_tool_result": + return _emit_mcp_tool_result(content, flow, predictive_handler) + if content_type == "text_reasoning": + return _emit_text_reasoning(content) logger.debug("Skipping unsupported content type in AG-UI emitter: %s", content_type) return [] diff --git a/python/packages/ag-ui/tests/ag_ui/test_http_round_trip.py b/python/packages/ag-ui/tests/ag_ui/test_http_round_trip.py index 7e4712535c..5a86a6ff59 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_http_round_trip.py +++ b/python/packages/ag-ui/tests/ag_ui/test_http_round_trip.py @@ -213,3 +213,134 @@ def test_sse_response_headers() -> None: assert response.headers["content-type"] == "text/event-stream; charset=utf-8" assert response.headers.get("cache-control") == "no-cache" + + +# ── MCP tool call SSE round-trip ── + + +def test_mcp_tool_call_sse_round_trip() -> None: + """MCP tool call + result events survive SSE encoding/parsing round-trip.""" + app = _build_app_with_agent( + [ + AgentResponseUpdate( + contents=[ + Content.from_mcp_server_tool_call( + call_id="mcp-1", + tool_name="search", + server_name="brave", + arguments={"query": "weather"}, + ) + ], + role="assistant", + ), + AgentResponseUpdate( + contents=[ + Content.from_mcp_server_tool_result( + call_id="mcp-1", + output={"results": ["sunny"]}, + ) + ], + role="assistant", + ), + AgentResponseUpdate( + contents=[Content.from_text(text="It's sunny!")], + role="assistant", + ), + ] + ) + client = TestClient(app) + response = client.post("/", json=USER_PAYLOAD) + + assert response.status_code == 200 + stream = parse_sse_to_event_stream(response.content) + stream.assert_bookends() + stream.assert_tool_calls_balanced() + stream.assert_text_messages_balanced() + stream.assert_no_run_error() + + # Verify MCP tool call details survive SSE encoding + start = stream.first("TOOL_CALL_START") + assert start.tool_call_name == "search" + assert start.tool_call_id == "mcp-1" + + # Verify the result came through + result = stream.first("TOOL_CALL_RESULT") + assert "sunny" in result.content + + +# ── Text reasoning SSE round-trip ── + + +def test_text_reasoning_sse_round_trip() -> None: + """Text reasoning events survive SSE encoding/parsing round-trip.""" + app = _build_app_with_agent( + [ + AgentResponseUpdate( + contents=[ + Content.from_text_reasoning( + id="reason-1", + text="The user wants weather info, I should use a tool.", + ) + ], + role="assistant", + ), + AgentResponseUpdate( + contents=[Content.from_text(text="Let me check the weather.")], + role="assistant", + ), + ] + ) + client = TestClient(app) + response = client.post("/", json=USER_PAYLOAD) + + assert response.status_code == 200 + stream = parse_sse_to_event_stream(response.content) + stream.assert_bookends() + stream.assert_text_messages_balanced() + stream.assert_no_run_error() + stream.assert_has_type("REASONING_START") + stream.assert_has_type("REASONING_MESSAGE_CONTENT") + stream.assert_has_type("REASONING_END") + + # Verify reasoning content survives SSE encoding + raw_events = parse_sse_response(response.content) + reasoning_content = [e for e in raw_events if e["type"] == "REASONING_MESSAGE_CONTENT"] + assert len(reasoning_content) == 1 + assert "weather" in reasoning_content[0]["delta"] + + +def test_text_reasoning_with_encrypted_value_sse_round_trip() -> None: + """Reasoning with protected_data emits ReasoningEncryptedValue through SSE.""" + app = _build_app_with_agent( + [ + AgentResponseUpdate( + contents=[ + Content.from_text_reasoning( + id="reason-enc", + text="visible reasoning", + protected_data="encrypted-payload-abc123", + ) + ], + role="assistant", + ), + AgentResponseUpdate( + contents=[Content.from_text(text="Done.")], + role="assistant", + ), + ] + ) + client = TestClient(app) + response = client.post("/", json=USER_PAYLOAD) + + assert response.status_code == 200 + stream = parse_sse_to_event_stream(response.content) + stream.assert_bookends() + stream.assert_no_run_error() + stream.assert_has_type("REASONING_ENCRYPTED_VALUE") + + raw_events = parse_sse_response(response.content) + encrypted = [e for e in raw_events if e["type"] == "REASONING_ENCRYPTED_VALUE"] + assert len(encrypted) == 1 + assert encrypted[0]["encryptedValue"] == "encrypted-payload-abc123" + assert encrypted[0]["entityId"] == "reason-enc" + assert encrypted[0]["subtype"] == "message" diff --git a/python/packages/ag-ui/tests/ag_ui/test_run.py b/python/packages/ag-ui/tests/ag_ui/test_run.py index 5a0cd1605c..ae8c5e85b0 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_run.py +++ b/python/packages/ag-ui/tests/ag_ui/test_run.py @@ -5,6 +5,12 @@ import pytest from ag_ui.core import ( CustomEvent, + ReasoningEncryptedValueEvent, + ReasoningEndEvent, + ReasoningMessageContentEvent, + ReasoningMessageEndEvent, + ReasoningMessageStartEvent, + ReasoningStartEvent, TextMessageEndEvent, TextMessageStartEvent, ToolCallArgsEvent, @@ -25,7 +31,10 @@ _build_run_finished_event, _emit_approval_request, _emit_content, + _emit_mcp_tool_call, + _emit_mcp_tool_result, _emit_text, + _emit_text_reasoning, _emit_tool_call, _emit_tool_result, _extract_resume_payload, @@ -991,3 +1000,349 @@ def test_emit_oauth_consent_request_no_link(): events = _emit_content(content, flow) assert len(events) == 0 + + +# ============================================================================ +# Tests for MCP tool call, MCP tool result, and text reasoning event emission +# ============================================================================ + + +class TestEmitMcpToolCall: + """Tests for _emit_mcp_tool_call function.""" + + def test_produces_start_and_args_events(self): + """MCP tool call emits ToolCallStart + ToolCallArgs events.""" + flow = FlowState() + content = Content.from_mcp_server_tool_call( + call_id="mcp_call_1", + tool_name="search", + server_name="brave", + arguments={"query": "weather"}, + ) + + events = _emit_mcp_tool_call(content, flow) + + assert len(events) == 2 + assert events[0].type == "TOOL_CALL_START" + assert events[0].tool_call_id == "mcp_call_1" + assert events[0].tool_call_name == "search" + assert events[1].type == "TOOL_CALL_ARGS" + assert events[1].tool_call_id == "mcp_call_1" + assert "weather" in events[1].delta + + def test_tracks_in_flow_state(self): + """MCP tool call is tracked in flow.pending_tool_calls and tool_calls_by_id.""" + flow = FlowState() + content = Content.from_mcp_server_tool_call( + call_id="mcp_call_2", + tool_name="get_file", + arguments='{"path": "/tmp/test.txt"}', + ) + + _emit_mcp_tool_call(content, flow) + + assert len(flow.pending_tool_calls) == 1 + assert flow.pending_tool_calls[0]["id"] == "mcp_call_2" + assert "mcp_call_2" in flow.tool_calls_by_id + assert flow.tool_calls_by_id["mcp_call_2"]["function"]["name"] == "get_file" + assert flow.tool_calls_by_id["mcp_call_2"]["function"]["arguments"] == '{"path": "/tmp/test.txt"}' + + def test_no_server_name_uses_tool_name_only(self): + """Without server_name, display name is just tool_name.""" + flow = FlowState() + content = Content.from_mcp_server_tool_call( + call_id="mcp_call_3", + tool_name="list_files", + ) + + events = _emit_mcp_tool_call(content, flow) + + assert events[0].tool_call_name == "list_files" + + def test_no_arguments_skips_args_event(self): + """No arguments produces only ToolCallStart, no ToolCallArgs.""" + flow = FlowState() + content = Content.from_mcp_server_tool_call( + call_id="mcp_call_4", + tool_name="ping", + ) + + events = _emit_mcp_tool_call(content, flow) + + assert len(events) == 1 + assert events[0].type == "TOOL_CALL_START" + + def test_generates_id_when_missing(self): + """A tool_call_id is generated when call_id is None.""" + flow = FlowState() + content = Content(type="mcp_server_tool_call", tool_name="test_tool") + + events = _emit_mcp_tool_call(content, flow) + + assert len(events) >= 1 + assert events[0].tool_call_id is not None + assert events[0].tool_call_id != "" + assert events[0].tool_call_name == "test_tool" + + def test_missing_tool_name_falls_back_to_mcp_tool(self): + """When tool_name is None, the fallback 'mcp_tool' is used.""" + flow = FlowState() + content = Content(type="mcp_server_tool_call") + + events = _emit_mcp_tool_call(content, flow) + + assert len(events) >= 1 + assert events[0].tool_call_name == "mcp_tool" + + +class TestEmitMcpToolResult: + """Tests for _emit_mcp_tool_result function.""" + + def test_produces_end_and_result_events(self): + """MCP tool result emits ToolCallEnd + ToolCallResult events.""" + flow = FlowState() + content = Content.from_mcp_server_tool_result( + call_id="mcp_call_1", + output={"results": [{"title": "Weather", "url": "https://example.com"}]}, + ) + + events = _emit_mcp_tool_result(content, flow) + + assert len(events) == 2 + assert events[0].type == "TOOL_CALL_END" + assert events[0].tool_call_id == "mcp_call_1" + assert events[1].type == "TOOL_CALL_RESULT" + assert events[1].tool_call_id == "mcp_call_1" + assert "Weather" in events[1].content + + def test_tracks_in_flow_state(self): + """MCP tool result is tracked in flow.tool_results and tool_calls_ended.""" + flow = FlowState() + content = Content.from_mcp_server_tool_result( + call_id="mcp_call_5", + output="Success", + ) + + _emit_mcp_tool_result(content, flow) + + assert "mcp_call_5" in flow.tool_calls_ended + assert len(flow.tool_results) == 1 + assert flow.tool_results[0]["toolCallId"] == "mcp_call_5" + assert flow.tool_results[0]["content"] == "Success" + + def test_no_call_id_returns_empty(self): + """Missing call_id returns empty events list with a warning.""" + flow = FlowState() + content = Content(type="mcp_server_tool_result", output="data") + + events = _emit_mcp_tool_result(content, flow) + + assert events == [] + + def test_serializes_non_string_output(self): + """Non-string output is serialized to JSON.""" + flow = FlowState() + content = Content.from_mcp_server_tool_result( + call_id="mcp_call_6", + output={"key": "value", "count": 42}, + ) + + events = _emit_mcp_tool_result(content, flow) + + result_event = events[1] + assert isinstance(result_event.content, str) + assert '"key": "value"' in result_event.content + + def test_output_none_falls_back_to_empty_string(self): + """When output is None (default), the result content is an empty string.""" + flow = FlowState() + content = Content(type="mcp_server_tool_result", call_id="mcp_call_none") + + events = _emit_mcp_tool_result(content, flow) + + assert len(events) == 2 + assert events[1].type == "TOOL_CALL_RESULT" + assert events[1].content == "" + + def test_resets_flow_state_like_emit_tool_result(self): + """MCP tool result performs same FlowState cleanup as _emit_tool_result.""" + flow = FlowState() + flow.tool_call_id = "mcp_call_7" + flow.tool_call_name = "brave/search" + flow.message_id = "open-msg-456" + flow.accumulated_text = "Let me search for that..." + + content = Content.from_mcp_server_tool_result( + call_id="mcp_call_7", + output="search results", + ) + + events = _emit_mcp_tool_result(content, flow) + + assert flow.tool_call_id is None + assert flow.tool_call_name is None + assert flow.message_id is None + assert flow.accumulated_text == "" + + text_end_events = [e for e in events if isinstance(e, TextMessageEndEvent)] + assert len(text_end_events) == 1 + assert text_end_events[0].message_id == "open-msg-456" + + def test_no_open_message_skips_text_end(self): + """MCP tool result without open text message skips TextMessageEndEvent.""" + flow = FlowState() + flow.message_id = None + + content = Content.from_mcp_server_tool_result( + call_id="mcp_call_8", + output="result", + ) + + events = _emit_mcp_tool_result(content, flow) + + text_end_events = [e for e in events if isinstance(e, TextMessageEndEvent)] + assert len(text_end_events) == 0 + + def test_predictive_handler_emits_state_snapshot(self): + """MCP tool result applies pending updates and emits StateSnapshotEvent when predictive_handler is set.""" + from unittest.mock import MagicMock + + from ag_ui.core import StateSnapshotEvent + + flow = FlowState() + flow.current_state = {"doc": "hello"} + content = Content.from_mcp_server_tool_result( + call_id="mcp_call_9", + output="done", + ) + + handler = MagicMock() + events = _emit_mcp_tool_result(content, flow, predictive_handler=handler) + + handler.apply_pending_updates.assert_called_once() + snapshot_events = [e for e in events if isinstance(e, StateSnapshotEvent)] + assert len(snapshot_events) == 1 + assert snapshot_events[0].snapshot == {"doc": "hello"} + + +class TestEmitTextReasoning: + """Tests for _emit_text_reasoning function.""" + + def test_produces_reasoning_events(self): + """Text reasoning emits the full reasoning event sequence.""" + content = Content.from_text_reasoning( + id="reason_1", + text="The user is asking about weather, so I should call the weather tool.", + ) + + events = _emit_text_reasoning(content) + + assert len(events) == 5 + assert isinstance(events[0], ReasoningStartEvent) + assert events[0].message_id == "reason_1" + assert isinstance(events[1], ReasoningMessageStartEvent) + assert events[1].message_id == "reason_1" + assert events[1].role == "assistant" + assert isinstance(events[2], ReasoningMessageContentEvent) + assert events[2].message_id == "reason_1" + assert events[2].delta == "The user is asking about weather, so I should call the weather tool." + assert isinstance(events[3], ReasoningMessageEndEvent) + assert events[3].message_id == "reason_1" + assert isinstance(events[4], ReasoningEndEvent) + assert events[4].message_id == "reason_1" + + def test_protected_data_emits_encrypted_value_event(self): + """protected_data is emitted as a ReasoningEncryptedValueEvent.""" + content = Content.from_text_reasoning( + id="reason_2", + text="visible reasoning", + protected_data="encrypted metadata", + ) + + events = _emit_text_reasoning(content) + + encrypted_events = [e for e in events if isinstance(e, ReasoningEncryptedValueEvent)] + assert len(encrypted_events) == 1 + assert encrypted_events[0].subtype == "message" + assert encrypted_events[0].entity_id == "reason_2" + assert encrypted_events[0].encrypted_value == "encrypted metadata" + + def test_protected_data_only_emits_event(self): + """Content with only protected_data (no text) still emits reasoning events.""" + content = Content.from_text_reasoning( + protected_data="encrypted reasoning content", + ) + + events = _emit_text_reasoning(content) + + # Should have start, msg_start, msg_end, encrypted_value, end (no content event) + assert len(events) == 5 + assert isinstance(events[0], ReasoningStartEvent) + assert isinstance(events[1], ReasoningMessageStartEvent) + assert isinstance(events[2], ReasoningMessageEndEvent) + assert isinstance(events[3], ReasoningEncryptedValueEvent) + assert events[3].encrypted_value == "encrypted reasoning content" + assert isinstance(events[4], ReasoningEndEvent) + + def test_empty_text_and_no_protected_data_returns_empty(self): + """Empty text and no protected_data returns no events.""" + content = Content.from_text_reasoning() + + events = _emit_text_reasoning(content) + + assert events == [] + + def test_generates_message_id_when_missing(self): + """When id is None, a message_id is generated.""" + content = Content.from_text_reasoning(text="thinking...") + + events = _emit_text_reasoning(content) + + assert len(events) == 5 + assert events[0].message_id is not None + assert events[0].message_id != "" + # All events share the same message_id + assert events[1].message_id == events[0].message_id + + +class TestEmitContentMcpRouting: + """Tests that _emit_content correctly routes MCP and reasoning types.""" + + def test_routes_mcp_server_tool_call(self): + """_emit_content dispatches mcp_server_tool_call to _emit_mcp_tool_call.""" + flow = FlowState() + content = Content.from_mcp_server_tool_call( + call_id="route_test_1", + tool_name="test_tool", + server_name="test_server", + ) + + events = _emit_content(content, flow) + + assert len(events) >= 1 + assert events[0].type == "TOOL_CALL_START" + assert events[0].tool_call_name == "test_tool" + + def test_routes_mcp_server_tool_result(self): + """_emit_content dispatches mcp_server_tool_result to _emit_mcp_tool_result.""" + flow = FlowState() + content = Content.from_mcp_server_tool_result( + call_id="route_test_2", + output="result data", + ) + + events = _emit_content(content, flow) + + assert len(events) == 2 + assert events[0].type == "TOOL_CALL_END" + assert events[1].type == "TOOL_CALL_RESULT" + + def test_routes_text_reasoning(self): + """_emit_content dispatches text_reasoning to _emit_text_reasoning.""" + flow = FlowState() + content = Content.from_text_reasoning(text="I need to think about this...") + + events = _emit_content(content, flow) + + assert len(events) == 5 + assert isinstance(events[0], ReasoningStartEvent)