diff --git a/python/packages/core/agent_framework/orchestrations/__init__.py b/python/packages/core/agent_framework/orchestrations/__init__.py index fa3561f22f..67b92f3ad6 100644 --- a/python/packages/core/agent_framework/orchestrations/__init__.py +++ b/python/packages/core/agent_framework/orchestrations/__init__.py @@ -38,8 +38,6 @@ # Orchestration helpers "AgentRequestInfoResponse", "OrchestrationState", - "clean_conversation_for_handoff", - "create_completion_message", # Group Chat "AgentBasedGroupChatOrchestrator", "AgentOrchestrationOutput", diff --git a/python/packages/orchestrations/agent_framework_orchestrations/__init__.py b/python/packages/orchestrations/agent_framework_orchestrations/__init__.py index d1acb7af53..d4de03e6e7 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/__init__.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/__init__.py @@ -61,8 +61,7 @@ StandardMagenticManager, ) from ._orchestration_request_info import AgentRequestInfoResponse -from ._orchestration_state import OrchestrationState -from ._orchestrator_helpers import clean_conversation_for_handoff, create_completion_message +from ._orchestration_shared import OrchestrationState from ._sequential import SequentialBuilder __all__ = [ @@ -105,6 +104,4 @@ "StandardMagenticManager", "TerminationCondition", "__version__", - "clean_conversation_for_handoff", - "create_completion_message", ] diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_base_group_chat_orchestrator.py b/python/packages/orchestrations/agent_framework_orchestrations/_base_group_chat_orchestrator.py index f01f3700f7..422d3b11b2 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_base_group_chat_orchestrator.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_base_group_chat_orchestrator.py @@ -20,6 +20,7 @@ from typing_extensions import Never from ._orchestration_request_info import AgentApprovalExecutor +from ._orchestration_shared import OrchestrationOutput if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover @@ -351,7 +352,7 @@ async def _check_termination(self) -> bool: result = await result return result - async def _check_terminate_and_yield(self, ctx: WorkflowContext[Never, list[Message]]) -> bool: + async def _check_terminate_and_yield(self, ctx: WorkflowContext[Never, OrchestrationOutput]) -> bool: """Check termination conditions and yield completion if met. Args: @@ -363,7 +364,7 @@ async def _check_terminate_and_yield(self, ctx: WorkflowContext[Never, list[Mess terminate = await self._check_termination() if terminate: self._append_messages([self._create_completion_message(self.TERMINATION_CONDITION_MET_MESSAGE)]) - await ctx.yield_output(self._full_conversation) + await ctx.yield_output(OrchestrationOutput(messages=self._full_conversation)) return True return False @@ -490,7 +491,7 @@ def _check_round_limit(self) -> bool: return False - async def _check_round_limit_and_yield(self, ctx: WorkflowContext[Never, list[Message]]) -> bool: + async def _check_round_limit_and_yield(self, ctx: WorkflowContext[Never, OrchestrationOutput]) -> bool: """Check round limit and yield completion if reached. Args: @@ -502,7 +503,7 @@ async def _check_round_limit_and_yield(self, ctx: WorkflowContext[Never, list[Me reach_max_rounds = self._check_round_limit() if reach_max_rounds: self._append_messages([self._create_completion_message(self.MAX_ROUNDS_MET_MESSAGE)]) - await ctx.yield_output(self._full_conversation) + await ctx.yield_output(OrchestrationOutput(messages=self._full_conversation)) return True return False @@ -521,7 +522,7 @@ async def on_checkpoint_save(self) -> dict[str, Any]: Returns: Serialized state dict """ - from ._orchestration_state import OrchestrationState + from ._orchestration_shared import OrchestrationState state = OrchestrationState( conversation=list(self._full_conversation), @@ -551,7 +552,7 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: Args: state: Serialized state dict """ - from ._orchestration_state import OrchestrationState + from ._orchestration_shared import OrchestrationState orch_state = OrchestrationState.from_dict(state) self._full_conversation = list(orch_state.conversation) diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_concurrent.py b/python/packages/orchestrations/agent_framework_orchestrations/_concurrent.py index 062e87806c..aa75c54116 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_concurrent.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_concurrent.py @@ -1,10 +1,9 @@ # Copyright (c) Microsoft. All rights reserved. -import asyncio import inspect import logging from collections.abc import Callable, Sequence -from typing import Any +from typing import Any, Awaitable from agent_framework import Message, SupportsAgentRun from agent_framework._workflows._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse @@ -18,6 +17,7 @@ from typing_extensions import Never from ._orchestration_request_info import AgentApprovalExecutor +from ._orchestration_shared import OrchestrationOutput logger = logging.getLogger(__name__) @@ -32,13 +32,11 @@ - Participants can be provided as SupportsAgentRun or Executor instances via `participants=[...]`. - A custom aggregator can be provided as: - an Executor instance (it should handle list[AgentExecutorResponse], - yield output), or + yield OrchestrationOutput), or - a callback function with signature: - def cb(results: list[AgentExecutorResponse]) -> Any | None - def cb(results: list[AgentExecutorResponse], ctx: WorkflowContext) -> Any | None - The callback is wrapped in _CallbackAggregator. - If the callback returns a non-None value, _CallbackAggregator yields that as output. - If it returns None, the callback may have already yielded an output via ctx, so no further action is taken. + def cb(results: list[AgentExecutorResponse]) -> list[Message] + The callback is wrapped in _CallbackAggregator, which wraps the returned + list[Message] in an OrchestrationOutput and yields it. """ @@ -82,7 +80,9 @@ class _AggregateAgentConversations(Executor): """ @handler - async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, list[Message]]) -> None: + async def aggregate( + self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, OrchestrationOutput] + ) -> None: if not results: logger.error("Concurrent aggregator received empty results list") raise ValueError("Aggregation failed: no results provided") @@ -137,47 +137,43 @@ def _is_role(msg: Any, role: str) -> bool: logger.warning("No user prompt found in any conversation; emitting assistants only") output.extend(assistant_replies) - await ctx.yield_output(output) + await ctx.yield_output(OrchestrationOutput(messages=output)) class _CallbackAggregator(Executor): """Wraps a Python callback as an aggregator. - Accepts either an async or sync callback with one of the signatures: - - (results: list[AgentExecutorResponse]) -> Any | None - - (results: list[AgentExecutorResponse], ctx: WorkflowContext[Any]) -> Any | None + Accepts either an async or sync callback with the following signature: + - (results: list[AgentExecutorResponse]) -> list[Message] + + The returned list[Message] is automatically wrapped in an OrchestrationOutput. Notes: - Async callbacks are awaited directly. - Sync callbacks are executed via asyncio.to_thread to avoid blocking the event loop. - - If the callback returns a non-None value, it is yielded as an output. + - If the callback returns a non-None value, it is wrapped in OrchestrationOutput and yielded. """ - def __init__(self, callback: Callable[..., Any], id: str | None = None) -> None: + def __init__( + self, + callback: Callable[[list[AgentExecutorResponse]], Awaitable[list[Message]] | list[Message]], + id: str | None = None, + ) -> None: derived_id = getattr(callback, "__name__", "") or "" if not derived_id or derived_id == "": derived_id = f"{type(self).__name__}_unnamed" super().__init__(id or derived_id) self._callback = callback - self._param_count = len(inspect.signature(callback).parameters) @handler - async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, Any]) -> None: - # Call according to provided signature, always non-blocking for sync callbacks - if self._param_count >= 2: - if inspect.iscoroutinefunction(self._callback): - ret = await self._callback(results, ctx) # type: ignore[misc] - else: - ret = await asyncio.to_thread(self._callback, results, ctx) - else: - if inspect.iscoroutinefunction(self._callback): - ret = await self._callback(results) # type: ignore[misc] - else: - ret = await asyncio.to_thread(self._callback, results) + async def aggregate( + self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, OrchestrationOutput] + ) -> None: + ret = self._callback(results) + if inspect.isawaitable(ret): + ret = await ret - # If the callback returned a value, finalize the workflow with it - if ret is not None: - await ctx.yield_output(ret) + await ctx.yield_output(OrchestrationOutput(messages=list(ret))) class ConcurrentBuilder: @@ -193,7 +189,7 @@ class ConcurrentBuilder: from agent_framework_orchestrations import ConcurrentBuilder - # Minimal: use default aggregator (returns list[Message]) + # Minimal: use default aggregator (returns OrchestrationOutput) workflow = ConcurrentBuilder(participants=[agent1, agent2, agent3]).build() @@ -265,17 +261,14 @@ def _set_participants(self, participants: Sequence[SupportsAgentRun | Executor]) def with_aggregator( self, - aggregator: Executor - | Callable[[list[AgentExecutorResponse]], Any] - | Callable[[list[AgentExecutorResponse], WorkflowContext[Never, Any]], Any], + aggregator: Executor | Callable[[list[AgentExecutorResponse]], Awaitable[list[Message]] | list[Message]], ) -> "ConcurrentBuilder": r"""Override the default aggregator with an executor or a callback. - Executor: must handle `list[AgentExecutorResponse]` and yield output using `ctx.yield_output(...)` - - Callback: sync or async callable with one of the signatures: - `(results: list[AgentExecutorResponse]) -> Any | None` or - `(results: list[AgentExecutorResponse], ctx: WorkflowContext) -> Any | None`. - If the callback returns a non-None value, it becomes the workflow's output. + - Callback: sync or async callable with the signature: + `(results: list[AgentExecutorResponse]) -> list[Message]`. + The returned list[Message] is automatically wrapped in an OrchestrationOutput. Args: aggregator: Executor instance, or callback function @@ -287,23 +280,15 @@ def with_aggregator( class CustomAggregator(Executor): @handler async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext) -> None: - await ctx.yield_output(" | ".join(r.agent_response.messages[-1].text for r in results)) + await ctx.yield_output(OrchestrationOutput(messages=[...])) wf = ConcurrentBuilder(participants=[a1, a2, a3]).with_aggregator(CustomAggregator()).build() - # Callback-based aggregator (string result) - async def summarize(results: list[AgentExecutorResponse]) -> str: - return " | ".join(r.agent_response.messages[-1].text for r in results) - - - wf = ConcurrentBuilder(participants=[a1, a2, a3]).with_aggregator(summarize).build() - - - # Callback-based aggregator (yield result) - async def summarize(results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None: - await ctx.yield_output(" | ".join(r.agent_response.messages[-1].text for r in results)) + # Callback-based aggregator (returns list[Message], wrapped in OrchestrationOutput) + async def summarize(results: list[AgentExecutorResponse]) -> list[Message]: + return [r.agent_response.messages[-1] for r in results] wf = ConcurrentBuilder(participants=[a1, a2, a3]).with_aggregator(summarize).build() diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_group_chat.py b/python/packages/orchestrations/agent_framework_orchestrations/_group_chat.py index a99e221409..98d32b85d1 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_group_chat.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_group_chat.py @@ -50,7 +50,7 @@ TerminationCondition, ) from ._orchestration_request_info import AgentApprovalExecutor -from ._orchestrator_helpers import clean_conversation_for_handoff +from ._orchestration_shared import OrchestrationOutput, filter_tool_contents if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover @@ -169,7 +169,7 @@ async def _handle_messages( """Initialize orchestrator state and start the conversation loop.""" self._append_messages(messages) # Termination condition will also be applied to the input messages - if await self._check_terminate_and_yield(cast(WorkflowContext[Never, list[Message]], ctx)): + if await self._check_terminate_and_yield(cast(WorkflowContext[Never, OrchestrationOutput], ctx)): return next_speaker = await self._get_next_speaker() @@ -195,12 +195,12 @@ async def _handle_response( """Handle a participant response.""" messages = self._process_participant_response(response) # Remove tool-related content to prevent API errors from empty messages - messages = clean_conversation_for_handoff(messages) + messages = filter_tool_contents(messages) self._append_messages(messages) - if await self._check_terminate_and_yield(cast(WorkflowContext[Never, list[Message]], ctx)): + if await self._check_terminate_and_yield(cast(WorkflowContext[Never, OrchestrationOutput], ctx)): return - if await self._check_round_limit_and_yield(cast(WorkflowContext[Never, list[Message]], ctx)): + if await self._check_round_limit_and_yield(cast(WorkflowContext[Never, OrchestrationOutput], ctx)): return next_speaker = await self._get_next_speaker() @@ -332,13 +332,13 @@ async def _handle_messages( """Initialize orchestrator state and start the conversation loop.""" self._append_messages(messages) # Termination condition will also be applied to the input messages - if await self._check_terminate_and_yield(cast(WorkflowContext[Never, list[Message]], ctx)): + if await self._check_terminate_and_yield(cast(WorkflowContext[Never, OrchestrationOutput], ctx)): return agent_orchestration_output = await self._invoke_agent() if await self._check_agent_terminate_and_yield( agent_orchestration_output, - cast(WorkflowContext[Never, list[Message]], ctx), + cast(WorkflowContext[Never, OrchestrationOutput], ctx), ): return @@ -364,17 +364,17 @@ async def _handle_response( """Handle a participant response.""" messages = self._process_participant_response(response) # Remove tool-related content to prevent API errors from empty messages - messages = clean_conversation_for_handoff(messages) + messages = filter_tool_contents(messages) self._append_messages(messages) - if await self._check_terminate_and_yield(cast(WorkflowContext[Never, list[Message]], ctx)): + if await self._check_terminate_and_yield(cast(WorkflowContext[Never, OrchestrationOutput], ctx)): return - if await self._check_round_limit_and_yield(cast(WorkflowContext[Never, list[Message]], ctx)): + if await self._check_round_limit_and_yield(cast(WorkflowContext[Never, OrchestrationOutput], ctx)): return agent_orchestration_output = await self._invoke_agent() if await self._check_agent_terminate_and_yield( agent_orchestration_output, - cast(WorkflowContext[Never, list[Message]], ctx), + cast(WorkflowContext[Never, OrchestrationOutput], ctx), ): return @@ -522,7 +522,7 @@ async def _invoke_agent_helper(conversation: list[Message]) -> AgentOrchestratio async def _check_agent_terminate_and_yield( self, agent_orchestration_output: AgentOrchestrationOutput, - ctx: WorkflowContext[Never, list[Message]], + ctx: WorkflowContext[Never, OrchestrationOutput], ) -> bool: """Check if the agent requested termination and yield completion if so. @@ -537,7 +537,7 @@ async def _check_agent_terminate_and_yield( agent_orchestration_output.final_message or "The conversation has been terminated by the agent." ) self._append_messages([self._create_completion_message(final_message)]) - await ctx.yield_output(self._full_conversation) + await ctx.yield_output(OrchestrationOutput(messages=self._full_conversation)) return True return False diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py index 4352a8af47..9359213285 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py @@ -54,7 +54,7 @@ from agent_framework._workflows._workflow_context import WorkflowContext from ._base_group_chat_orchestrator import TerminationCondition -from ._orchestrator_helpers import clean_conversation_for_handoff +from ._orchestration_shared import OrchestrationOutput, filter_tool_contents if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover @@ -428,7 +428,7 @@ def _handoff_tool() -> None: async def _run_agent_and_emit(self, ctx: WorkflowContext[Any, Any]) -> None: """Override to support handoff.""" incoming_messages = list(self._cache) - cleaned_incoming_messages = clean_conversation_for_handoff(incoming_messages) + cleaned_incoming_messages = filter_tool_contents(incoming_messages) runtime_tool_messages = [ message for message in incoming_messages @@ -492,7 +492,7 @@ async def _run_agent_and_emit(self, ctx: WorkflowContext[Any, Any]) -> None: # Remove function call related content from the agent response for broadcast. # This prevents replaying stale tool artifacts to other agents. - cleaned_response = clean_conversation_for_handoff(response.messages) + cleaned_response = filter_tool_contents(response.messages) # For internal tracking, preserve the full response (including function_calls) # in _full_conversation so that Azure OpenAI can match function_calls with @@ -624,7 +624,7 @@ def _is_handoff_requested(self, response: AgentResponse) -> str | None: return None - async def _check_terminate_and_yield(self, ctx: WorkflowContext[Any, Any]) -> bool: + async def _check_terminate_and_yield(self, ctx: WorkflowContext[Any, OrchestrationOutput]) -> bool: """Check termination conditions and yield completion if met. Args: @@ -641,7 +641,7 @@ async def _check_terminate_and_yield(self, ctx: WorkflowContext[Any, Any]) -> bo terminated = await terminated if terminated: - await ctx.yield_output(self._full_conversation) + await ctx.yield_output(OrchestrationOutput(messages=self._full_conversation)) return True return False diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py b/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py index b887d86df3..7b42f0f6cb 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py @@ -37,6 +37,7 @@ GroupChatWorkflowContextOutT, ParticipantRegistry, ) +from ._orchestration_shared import OrchestrationOutput if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover @@ -1055,7 +1056,9 @@ async def _run_inner_loop_helper( if self._magentic_context is None: raise RuntimeError("Context not initialized") # Check limits first - within_limits = await self._check_within_limits_or_complete(cast(WorkflowContext[Never, list[Message]], ctx)) + within_limits = await self._check_within_limits_or_complete( + cast(WorkflowContext[Never, OrchestrationOutput], ctx) + ) if not within_limits: return @@ -1090,7 +1093,7 @@ async def _run_inner_loop_helper( # Check for task completion if self._progress_ledger.is_request_satisfied.answer: logger.info("Magentic Orchestrator: Task completed") - await self._prepare_final_answer(cast(WorkflowContext[Never, list[Message]], ctx)) + await self._prepare_final_answer(cast(WorkflowContext[Never, OrchestrationOutput], ctx)) return # Check for stalling or looping @@ -1114,7 +1117,7 @@ async def _run_inner_loop_helper( if next_speaker not in self._participant_registry.participants: logger.warning(f"Invalid next speaker: {next_speaker}") - await self._prepare_final_answer(cast(WorkflowContext[Never, list[Message]], ctx)) + await self._prepare_final_answer(cast(WorkflowContext[Never, OrchestrationOutput], ctx)) return # Add instruction to conversation (assistant guidance) @@ -1190,7 +1193,7 @@ async def _run_outer_loop( # Start inner loop await self._run_inner_loop(ctx) - async def _prepare_final_answer(self, ctx: WorkflowContext[Never, list[Message]]) -> None: + async def _prepare_final_answer(self, ctx: WorkflowContext[Never, OrchestrationOutput]) -> None: """Prepare the final answer using the manager.""" if self._magentic_context is None: raise RuntimeError("Context not initialized") @@ -1199,11 +1202,11 @@ async def _prepare_final_answer(self, ctx: WorkflowContext[Never, list[Message]] final_answer = await self._manager.prepare_final_answer(self._magentic_context.clone(deep=True)) # Emit a completed event for the workflow - await ctx.yield_output([final_answer]) + await ctx.yield_output(OrchestrationOutput(messages=[*self._magentic_context.chat_history, final_answer])) self._terminated = True - async def _check_within_limits_or_complete(self, ctx: WorkflowContext[Never, list[Message]]) -> bool: + async def _check_within_limits_or_complete(self, ctx: WorkflowContext[Never, OrchestrationOutput]) -> bool: """Check if orchestrator is within operational limits. If limits are exceeded, yield a termination message and mark the workflow as terminated. @@ -1228,14 +1231,18 @@ async def _check_within_limits_or_complete(self, ctx: WorkflowContext[Never, lis logger.error(f"Magentic Orchestrator: Max {limit_type} count reached") # Yield the full conversation with an indication of termination due to limits - await ctx.yield_output([ - *self._magentic_context.chat_history, - Message( - role="assistant", - text=f"Workflow terminated due to reaching maximum {limit_type} count.", - author_name=MAGENTIC_MANAGER_NAME, - ), - ]) + await ctx.yield_output( + OrchestrationOutput( + messages=[ + *self._magentic_context.chat_history, + Message( + role="assistant", + text=f"Workflow terminated due to reaching maximum {limit_type} count.", + author_name=MAGENTIC_MANAGER_NAME, + ), + ] + ) + ) self._terminated = True return False diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_orchestration_state.py b/python/packages/orchestrations/agent_framework_orchestrations/_orchestration_shared.py similarity index 54% rename from python/packages/orchestrations/agent_framework_orchestrations/_orchestration_state.py rename to python/packages/orchestrations/agent_framework_orchestrations/_orchestration_shared.py index e8f8a81080..8c4786eea3 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_orchestration_state.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_orchestration_shared.py @@ -1,11 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. -"""Unified state management for group chat orchestrators. - -Provides OrchestrationState dataclass for standardized checkpoint serialization -across GroupChat, Handoff, and Magentic patterns. -""" - from __future__ import annotations from dataclasses import dataclass, field @@ -14,6 +8,17 @@ from agent_framework._types import Message +@dataclass +class OrchestrationOutput: + """Standardized output format for orchestrations. + + Attributes: + messages: List of messages representing the full conversation of the orchestration, including all agent turns. + """ + + messages: list[Message] + + def _new_chat_message_list() -> list[Message]: """Factory function for typed empty Message list. @@ -91,3 +96,61 @@ def from_dict(cls, data: dict[str, Any]) -> OrchestrationState: metadata=dict(data.get("metadata", {})), task=task, ) + + +def filter_tool_contents(conversation: list[Message]) -> list[Message]: + """Keep only plain text chat history for handoff routing. + + Handoff executors must not replay prior tool-control artifacts (function calls, + tool outputs, approval payloads) into future model turns, or providers may reject + the next request due to unmatched tool-call state. + + This helper builds a text-only copy of the conversation: + - Drops all non-text content from every message. + - Drops messages with no remaining text content. + - Preserves original roles and author names for retained text messages. + """ + cleaned: list[Message] = [] + for msg in conversation: + # Keep only plain text history for handoff routing. Tool-control content + # (function_call/function_result/approval payloads) is runtime-only and + # must not be replayed in future model turns. + text_parts = [content.text for content in msg.contents if content.type == "text" and content.text] + if not text_parts: + continue + + msg_copy = Message( + role=msg.role, + text=" ".join(text_parts), + author_name=msg.author_name, + additional_properties=dict(msg.additional_properties) if msg.additional_properties else None, + ) + cleaned.append(msg_copy) + + return cleaned + + +def create_completion_message( + *, + text: str | None = None, + author_name: str, + reason: str = "completed", +) -> Message: + """Create a standardized completion message. + + Simple helper to avoid duplicating completion message creation. + + Args: + text: Message text, or None to generate default + author_name: Author/orchestrator name + reason: Reason for completion (for default text generation) + + Returns: + Message with assistant role + """ + message_text = text or f"Conversation {reason}." + return Message( + role="assistant", + text=message_text, + author_name=author_name, + ) diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_orchestrator_helpers.py b/python/packages/orchestrations/agent_framework_orchestrations/_orchestrator_helpers.py deleted file mode 100644 index 1e13e4969a..0000000000 --- a/python/packages/orchestrations/agent_framework_orchestrations/_orchestrator_helpers.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Shared orchestrator utilities for group chat patterns. - -This module provides simple, reusable functions for common orchestration tasks. -No inheritance required - just import and call. -""" - -import logging - -from agent_framework._types import Message - -logger = logging.getLogger(__name__) - - -def clean_conversation_for_handoff(conversation: list[Message]) -> list[Message]: - """Keep only plain text chat history for handoff routing. - - Handoff executors must not replay prior tool-control artifacts (function calls, - tool outputs, approval payloads) into future model turns, or providers may reject - the next request due to unmatched tool-call state. - - This helper builds a text-only copy of the conversation: - - Drops all non-text content from every message. - - Drops messages with no remaining text content. - - Preserves original roles and author names for retained text messages. - """ - cleaned: list[Message] = [] - for msg in conversation: - # Keep only plain text history for handoff routing. Tool-control content - # (function_call/function_result/approval payloads) is runtime-only and - # must not be replayed in future model turns. - text_parts = [content.text for content in msg.contents if content.type == "text" and content.text] - if not text_parts: - continue - - msg_copy = Message( - role=msg.role, - text=" ".join(text_parts), - author_name=msg.author_name, - additional_properties=dict(msg.additional_properties) if msg.additional_properties else None, - ) - cleaned.append(msg_copy) - - return cleaned - - -def create_completion_message( - *, - text: str | None = None, - author_name: str, - reason: str = "completed", -) -> Message: - """Create a standardized completion message. - - Simple helper to avoid duplicating completion message creation. - - Args: - text: Message text, or None to generate default - author_name: Author/orchestrator name - reason: Reason for completion (for default text generation) - - Returns: - Message with assistant role - """ - message_text = text or f"Conversation {reason}." - return Message( - role="assistant", - text=message_text, - author_name=author_name, - ) diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_sequential.py b/python/packages/orchestrations/agent_framework_orchestrations/_sequential.py index 5ef4f7fe8c..46f2a5b71b 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_sequential.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_sequential.py @@ -8,7 +8,7 @@ - A shared conversation context (list[Message]) is passed along the chain - Agents append their assistant messages to the context - Custom executors can transform or summarize and return a refined context -- The workflow finishes with the final context produced by the last participant +- The workflow finishes with an OrchestrationOutput containing the final conversation Typical wiring: input -> _InputToConversation -> participant1 -> (agent? -> _ResponseToConversation) -> ... -> participantN -> _EndWithConversation @@ -27,8 +27,8 @@ - Agent response adaptation ("to-conversation:"): agents (via AgentExecutor) emit `AgentExecutorResponse`. The adapter converts that to a `list[Message]` using `full_conversation` so original prompts aren't lost when chaining. -- Result output ("end"): yields the final conversation list and the workflow becomes idle - giving a consistent terminal payload shape for both agents and custom executors. +- Result output ("end"): yields an OrchestrationOutput wrapping the final conversation + and the workflow becomes idle, giving a consistent terminal payload for both agents and custom executors. These adapters are first-class executors by design so they are type-checked at edges, observable (ExecutorInvoke/Completed events), and easily testable/reusable. Their IDs are @@ -57,6 +57,7 @@ from agent_framework._workflows._workflow_context import WorkflowContext from ._orchestration_request_info import AgentApprovalExecutor +from ._orchestration_shared import OrchestrationOutput logger = logging.getLogger(__name__) @@ -78,31 +79,31 @@ async def from_messages(self, messages: list[str | Message], ctx: WorkflowContex class _EndWithConversation(Executor): - """Terminates the workflow by emitting the final conversation context.""" + """Terminates the workflow by emitting an OrchestrationOutput containing the final conversation.""" @handler async def end_with_messages( self, conversation: list[Message], - ctx: WorkflowContext[Any, list[Message]], + ctx: WorkflowContext[Any, OrchestrationOutput], ) -> None: """Handler for ending with a list of Message. This is used when the last participant is a custom executor. """ - await ctx.yield_output(list(conversation)) + await ctx.yield_output(OrchestrationOutput(messages=list(conversation))) @handler async def end_with_agent_executor_response( self, response: AgentExecutorResponse, - ctx: WorkflowContext[Any, list[Message] | None], + ctx: WorkflowContext[Any, OrchestrationOutput], ) -> None: """Handle case where last participant is an agent. The agent is wrapped by AgentExecutor and emits AgentExecutorResponse. """ - await ctx.yield_output(response.full_conversation) + await ctx.yield_output(OrchestrationOutput(messages=list(response.full_conversation or []))) class SequentialBuilder: @@ -113,7 +114,7 @@ class SequentialBuilder: - The workflow wires participants in order, passing a list[Message] down the chain - Agents append their assistant messages to the conversation - Custom executors can transform/summarize and return a list[Message] - - The final output is the conversation produced by the last participant + - The final output is an OrchestrationOutput containing the conversation from the last participant Usage: @@ -252,7 +253,7 @@ def build(self) -> Workflow: route through a request info interceptor, then convert response to conversation via _ResponseToConversation - Else (custom Executor): pass conversation directly to the executor - - _EndWithConversation yields the final conversation and the workflow becomes idle + - _EndWithConversation yields an OrchestrationOutput and the workflow becomes idle """ # Internal nodes input_conv = _InputToConversation(id="input-conversation") diff --git a/python/packages/orchestrations/tests/test_concurrent.py b/python/packages/orchestrations/tests/test_concurrent.py index 8712aae3fd..13390a04d7 100644 --- a/python/packages/orchestrations/tests/test_concurrent.py +++ b/python/packages/orchestrations/tests/test_concurrent.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. -from typing import Any, cast +from typing import cast import pytest from agent_framework import ( @@ -17,6 +17,8 @@ from agent_framework.orchestrations import ConcurrentBuilder from typing_extensions import Never +from agent_framework_orchestrations._orchestration_shared import OrchestrationOutput + class _FakeAgentExec(Executor): """Test executor that mimics an agent by emitting an AgentExecutorResponse. @@ -58,18 +60,18 @@ async def test_concurrent_default_aggregator_emits_single_user_and_assistants() wf = ConcurrentBuilder(participants=[e1, e2, e3]).build() completed = False - output: list[Message] | None = None + output: OrchestrationOutput | None = None async for ev in wf.run("prompt: hello world", stream=True): if ev.type == "status" and ev.state == WorkflowRunState.IDLE: completed = True elif ev.type == "output": - output = cast(list[Message], ev.data) + output = cast(OrchestrationOutput, ev.data) if completed and output is not None: break assert completed assert output is not None - messages: list[Message] = output + messages: list[Message] = output.messages # Expect one user message + one assistant message per participant assert len(messages) == 1 + 3 @@ -86,68 +88,67 @@ async def test_concurrent_custom_aggregator_callback_is_used() -> None: e1 = _FakeAgentExec("agentA", "One") e2 = _FakeAgentExec("agentB", "Two") - async def summarize(results: list[AgentExecutorResponse]) -> str: + async def summarize(results: list[AgentExecutorResponse]) -> list[Message]: texts: list[str] = [] for r in results: msgs: list[Message] = r.agent_response.messages texts.append(msgs[-1].text if msgs else "") - return " | ".join(sorted(texts)) + return [Message(role="assistant", text=" | ".join(sorted(texts)))] wf = ConcurrentBuilder(participants=[e1, e2]).with_aggregator(summarize).build() completed = False - output: str | None = None + output: OrchestrationOutput | None = None async for ev in wf.run("prompt: custom", stream=True): if ev.type == "status" and ev.state == WorkflowRunState.IDLE: completed = True elif ev.type == "output": - output = cast(str, ev.data) + output = cast(OrchestrationOutput, ev.data) if completed and output is not None: break assert completed assert output is not None - # Custom aggregator returns a string payload - assert isinstance(output, str) - assert output == "One | Two" + assert isinstance(output, OrchestrationOutput) + assert output.messages[-1].text == "One | Two" async def test_concurrent_custom_aggregator_sync_callback_is_used() -> None: e1 = _FakeAgentExec("agentA", "One") e2 = _FakeAgentExec("agentB", "Two") - # Sync callback with ctx parameter (should run via asyncio.to_thread) - def summarize_sync(results: list[AgentExecutorResponse], _ctx: WorkflowContext[Any]) -> str: # type: ignore[unused-argument] + # Sync callback (should run via asyncio.to_thread) + def summarize_sync(results: list[AgentExecutorResponse]) -> list[Message]: texts: list[str] = [] for r in results: msgs: list[Message] = r.agent_response.messages texts.append(msgs[-1].text if msgs else "") - return " | ".join(sorted(texts)) + return [Message(role="assistant", text=" | ".join(sorted(texts)))] wf = ConcurrentBuilder(participants=[e1, e2]).with_aggregator(summarize_sync).build() completed = False - output: str | None = None + output: OrchestrationOutput | None = None async for ev in wf.run("prompt: custom sync", stream=True): if ev.type == "status" and ev.state == WorkflowRunState.IDLE: completed = True elif ev.type == "output": - output = cast(str, ev.data) + output = cast(OrchestrationOutput, ev.data) if completed and output is not None: break assert completed assert output is not None - assert isinstance(output, str) - assert output == "One | Two" + assert isinstance(output, OrchestrationOutput) + assert output.messages[-1].text == "One | Two" def test_concurrent_custom_aggregator_uses_callback_name_for_id() -> None: e1 = _FakeAgentExec("agentA", "One") e2 = _FakeAgentExec("agentB", "Two") - def summarize(results: list[AgentExecutorResponse]) -> str: # type: ignore[override] - return str(len(results)) + def summarize(results: list[AgentExecutorResponse]) -> list[Message]: + return [Message(role="assistant", text=str(len(results)))] wf = ConcurrentBuilder(participants=[e1, e2]).with_aggregator(summarize).build() @@ -193,8 +194,8 @@ async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowCon def test_concurrent_builder_rejects_multiple_calls_to_with_aggregator() -> None: """Test that multiple calls to .with_aggregator() raises an error.""" - def summarize(results: list[AgentExecutorResponse]) -> str: # type: ignore[override] - return str(len(results)) + def summarize(results: list[AgentExecutorResponse]) -> list[Message]: + return [Message(role="assistant", text=str(len(results)))] with pytest.raises(ValueError, match=r"with_aggregator\(\) has already been called"): ( @@ -215,7 +216,7 @@ async def test_concurrent_checkpoint_resume_round_trip() -> None: wf = ConcurrentBuilder(participants=list(participants), checkpoint_storage=storage).build() - baseline_output: list[Message] | None = None + baseline_output: OrchestrationOutput | None = None async for ev in wf.run("checkpoint concurrent", stream=True): if ev.type == "output": baseline_output = ev.data # type: ignore[assignment] @@ -236,7 +237,7 @@ async def test_concurrent_checkpoint_resume_round_trip() -> None: ) wf_resume = ConcurrentBuilder(participants=list(resumed_participants), checkpoint_storage=storage).build() - resumed_output: list[Message] | None = None + resumed_output: OrchestrationOutput | None = None async for ev in wf_resume.run(checkpoint_id=resume_checkpoint.checkpoint_id, stream=True): if ev.type == "output": resumed_output = ev.data # type: ignore[assignment] @@ -247,8 +248,8 @@ async def test_concurrent_checkpoint_resume_round_trip() -> None: break assert resumed_output is not None - assert [m.role for m in resumed_output] == [m.role for m in baseline_output] - assert [m.text for m in resumed_output] == [m.text for m in baseline_output] + assert [m.role for m in resumed_output.messages] == [m.role for m in baseline_output.messages] + assert [m.text for m in resumed_output.messages] == [m.text for m in baseline_output.messages] async def test_concurrent_checkpoint_runtime_only() -> None: @@ -258,7 +259,7 @@ async def test_concurrent_checkpoint_runtime_only() -> None: agents = [_FakeAgentExec(id="agent1", reply_text="A1"), _FakeAgentExec(id="agent2", reply_text="A2")] wf = ConcurrentBuilder(participants=agents).build() - baseline_output: list[Message] | None = None + baseline_output: OrchestrationOutput | None = None async for ev in wf.run("runtime checkpoint test", checkpoint_storage=storage, stream=True): if ev.type == "output": baseline_output = ev.data # type: ignore[assignment] @@ -278,7 +279,7 @@ async def test_concurrent_checkpoint_runtime_only() -> None: resumed_agents = [_FakeAgentExec(id="agent1", reply_text="A1"), _FakeAgentExec(id="agent2", reply_text="A2")] wf_resume = ConcurrentBuilder(participants=resumed_agents).build() - resumed_output: list[Message] | None = None + resumed_output: OrchestrationOutput | None = None async for ev in wf_resume.run( checkpoint_id=resume_checkpoint.checkpoint_id, checkpoint_storage=storage, stream=True ): @@ -291,7 +292,7 @@ async def test_concurrent_checkpoint_runtime_only() -> None: break assert resumed_output is not None - assert [m.role for m in resumed_output] == [m.role for m in baseline_output] + assert [m.role for m in resumed_output.messages] == [m.role for m in baseline_output.messages] async def test_concurrent_checkpoint_runtime_overrides_buildtime() -> None: @@ -307,7 +308,7 @@ async def test_concurrent_checkpoint_runtime_overrides_buildtime() -> None: agents = [_FakeAgentExec(id="agent1", reply_text="A1"), _FakeAgentExec(id="agent2", reply_text="A2")] wf = ConcurrentBuilder(participants=agents, checkpoint_storage=buildtime_storage).build() - baseline_output: list[Message] | None = None + baseline_output: OrchestrationOutput | None = None async for ev in wf.run("override test", checkpoint_storage=runtime_storage, stream=True): if ev.type == "output": baseline_output = ev.data # type: ignore[assignment] diff --git a/python/packages/orchestrations/tests/test_group_chat.py b/python/packages/orchestrations/tests/test_group_chat.py index 7550f820c7..eb9c99ef08 100644 --- a/python/packages/orchestrations/tests/test_group_chat.py +++ b/python/packages/orchestrations/tests/test_group_chat.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. from collections.abc import AsyncIterable, Awaitable, Callable, Sequence -from typing import Any, cast +from typing import Any, Literal, cast, overload import pytest from agent_framework import ( @@ -9,19 +9,18 @@ AgentExecutorResponse, AgentResponse, AgentResponseUpdate, + AgentRunInputs, AgentSession, BaseAgent, - ChatResponse, - ChatResponseUpdate, Content, Message, + ResponseStream, WorkflowEvent, WorkflowRunState, ) from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage from agent_framework.orchestrations import ( AgentRequestInfoResponse, - BaseGroupChatOrchestrator, GroupChatBuilder, GroupChatState, MagenticContext, @@ -30,25 +29,48 @@ MagenticProgressLedgerItem, ) +from agent_framework_orchestrations._base_group_chat_orchestrator import BaseGroupChatOrchestrator +from agent_framework_orchestrations._orchestration_shared import OrchestrationOutput + class StubAgent(BaseAgent): def __init__(self, agent_name: str, reply_text: str, **kwargs: Any) -> None: super().__init__(name=agent_name, description=f"Stub agent {agent_name}", **kwargs) self._reply_text = reply_text - def run( # type: ignore[override] + @overload + def run( self, - messages: str | Content | Message | Sequence[str | Content | Message] | None = None, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... + + @overload + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + + def run( + self, + messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> Awaitable[AgentResponse] | AsyncIterable[AgentResponseUpdate]: + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: if stream: - return self._run_stream_impl() + return ResponseStream(self._run_stream_impl(), finalizer=AgentResponse.from_updates) return self._run_impl() - async def _run_impl(self) -> AgentResponse: + async def _run_impl(self) -> AgentResponse[Any]: response = Message(role="assistant", text=self._reply_text, author_name=self.name) return AgentResponse(messages=[response]) @@ -61,20 +83,18 @@ async def _run_stream_impl(self) -> AsyncIterable[AgentResponseUpdate]: class MockChatClient: """Mock chat client that raises NotImplementedError for all methods.""" - additional_properties: dict[str, Any] + additional_properties: dict[str, Any] = {} - async def get_response( - self, messages: Any, stream: bool = False, **kwargs: Any - ) -> ChatResponse | AsyncIterable[ChatResponseUpdate]: + def get_response(self, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError class StubManagerAgent(Agent): def __init__(self) -> None: - super().__init__(client=MockChatClient(), name="manager_agent", description="Stub manager") + super().__init__(client=cast(Any, MockChatClient()), name="manager_agent", description="Stub manager") self._call_count = 0 - async def run( + async def run( # type: ignore[override] self, messages: str | Content | Message | Sequence[str | Content | Message] | None = None, *, @@ -84,7 +104,6 @@ async def run( if self._call_count == 0: self._call_count += 1 # First call: select the agent (using AgentOrchestrationOutput format) - payload = {"terminate": False, "reason": "Selecting agent", "next_speaker": "agent", "final_message": None} return AgentResponse( messages=[ Message( @@ -96,16 +115,9 @@ async def run( author_name=self.name, ) ], - value=payload, ) # Second call: terminate - payload = { - "terminate": True, - "reason": "Task complete", - "next_speaker": None, - "final_message": "agent manager final", - } return AgentResponse( messages=[ Message( @@ -117,7 +129,6 @@ async def run( author_name=self.name, ) ], - value=payload, ) @@ -125,10 +136,12 @@ class ConcatenatedJsonManagerAgent(Agent): """Manager agent that emits concatenated JSON in a single assistant message.""" def __init__(self) -> None: - super().__init__(client=MockChatClient(), name="concat_manager", description="Concatenated JSON manager") + super().__init__( + client=cast(Any, MockChatClient()), name="concat_manager", description="Concatenated JSON manager" + ) self._call_count = 0 - async def run( + async def run( # type: ignore[override] self, messages: str | Content | Message | Sequence[str | Content | Message] | None = None, *, @@ -230,17 +243,17 @@ async def test_group_chat_builder_basic_flow() -> None: orchestrator_name="manager", ).build() - outputs: list[list[Message]] = [] + outputs: list[OrchestrationOutput] = [] async for event in workflow.run("coordinate task", stream=True): if event.type == "output": data = event.data - if isinstance(data, list): - outputs.append(cast(list[Message], data)) + if isinstance(data, OrchestrationOutput): + outputs.append(data) assert len(outputs) == 1 - assert len(outputs[0]) >= 1 + assert len(outputs[0].messages) >= 1 # Check that both agents contributed - authors = {msg.author_name for msg in outputs[0] if msg.author_name in ["alpha", "beta"]} + authors = {msg.author_name for msg in outputs[0].messages if msg.author_name in ["alpha", "beta"]} assert len(authors) == 2 @@ -275,15 +288,15 @@ async def test_agent_manager_handles_concatenated_json_output() -> None: orchestrator_agent=manager, ).build() - outputs: list[list[Message]] = [] + outputs: list[OrchestrationOutput] = [] async for event in workflow.run("coordinate task", stream=True): if event.type == "output": data = event.data - if isinstance(data, list): - outputs.append(cast(list[Message], data)) + if isinstance(data, OrchestrationOutput): + outputs.append(data) assert outputs - conversation = outputs[-1] + conversation = outputs[-1].messages assert any(msg.author_name == "agent" and msg.text == "worker response" for msg in conversation) assert conversation[-1].author_name == manager.name assert conversation[-1].text == "concatenated manager final" @@ -345,18 +358,42 @@ class AgentWithoutName(BaseAgent): def __init__(self) -> None: super().__init__(name="", description="test") + @overload def run( - self, messages: Any = None, *, stream: bool = False, session: Any = None, **kwargs: Any - ) -> AgentResponse | AsyncIterable[AgentResponseUpdate]: - if stream: - - async def _stream() -> AsyncIterable[AgentResponseUpdate]: - yield AgentResponseUpdate(contents=[]) + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... + + @overload + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... - return _stream() + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: bool = False, + session: AgentSession | None = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: + if stream: + return ResponseStream(self._run_stream(), finalizer=AgentResponse.from_updates) return self._run_impl() - async def _run_impl(self) -> AgentResponse: + async def _run_stream(self) -> AsyncIterable[AgentResponseUpdate]: + yield AgentResponseUpdate(contents=[]) + + async def _run_impl(self) -> AgentResponse[Any]: return AgentResponse(messages=[]) agent = AgentWithoutName() @@ -392,17 +429,17 @@ def selector(state: GroupChatState) -> str: selection_func=selector, ).build() - outputs: list[list[Message]] = [] + outputs: list[OrchestrationOutput] = [] async for event in workflow.run("test task", stream=True): if event.type == "output": data = event.data - if isinstance(data, list): - outputs.append(cast(list[Message], data)) + if isinstance(data, OrchestrationOutput): + outputs.append(data) # Should have terminated due to max_rounds, expect at least one output assert len(outputs) >= 1 # The final message in the conversation should be about round limit - conversation = outputs[-1] + conversation = outputs[-1].messages assert len(conversation) >= 1 final_output = conversation[-1] assert "maximum number of rounds" in final_output.text.lower() @@ -425,15 +462,15 @@ def termination_condition(conversation: list[Message]) -> bool: selection_func=selector, ).build() - outputs: list[list[Message]] = [] + outputs: list[OrchestrationOutput] = [] async for event in workflow.run("test task", stream=True): if event.type == "output": data = event.data - if isinstance(data, list): - outputs.append(cast(list[Message], data)) + if isinstance(data, OrchestrationOutput): + outputs.append(data) assert outputs, "Expected termination to yield output" - conversation = outputs[-1] + conversation = outputs[-1].messages agent_replies = [msg for msg in conversation if msg.author_name == "agent" and msg.role == "assistant"] assert len(agent_replies) == 2 final_output = conversation[-1] @@ -451,15 +488,15 @@ async def test_termination_condition_agent_manager_finalizes(self) -> None: orchestrator_agent=manager, ).build() - outputs: list[list[Message]] = [] + outputs: list[OrchestrationOutput] = [] async for event in workflow.run("test task", stream=True): if event.type == "output": data = event.data - if isinstance(data, list): - outputs.append(cast(list[Message], data)) + if isinstance(data, OrchestrationOutput): + outputs.append(data) assert outputs, "Expected termination to yield output" - conversation = outputs[-1] + conversation = outputs[-1].messages assert conversation[-1].text == BaseGroupChatOrchestrator.TERMINATION_CONDITION_MET_MESSAGE assert conversation[-1].author_name == manager.name @@ -497,12 +534,12 @@ def selector(state: GroupChatState) -> str: selection_func=selector, ).build() - outputs: list[list[Message]] = [] + outputs: list[OrchestrationOutput] = [] async for event in workflow.run("test task", stream=True): if event.type == "output": data = event.data - if isinstance(data, list): - outputs.append(cast(list[Message], data)) + if isinstance(data, OrchestrationOutput): + outputs.append(data) assert len(outputs) == 1 # Should complete normally @@ -538,12 +575,12 @@ def selector(state: GroupChatState) -> str: workflow = GroupChatBuilder(participants=[agent], max_rounds=1, selection_func=selector).build() - outputs: list[list[Message]] = [] + outputs: list[OrchestrationOutput] = [] async for event in workflow.run("test string", stream=True): if event.type == "output": data = event.data - if isinstance(data, list): - outputs.append(cast(list[Message], data)) + if isinstance(data, OrchestrationOutput): + outputs.append(data) assert len(outputs) == 1 @@ -561,12 +598,12 @@ def selector(state: GroupChatState) -> str: workflow = GroupChatBuilder(participants=[agent], max_rounds=1, selection_func=selector).build() - outputs: list[list[Message]] = [] + outputs: list[OrchestrationOutput] = [] async for event in workflow.run(task_message, stream=True): if event.type == "output": data = event.data - if isinstance(data, list): - outputs.append(cast(list[Message], data)) + if isinstance(data, OrchestrationOutput): + outputs.append(data) assert len(outputs) == 1 @@ -587,12 +624,12 @@ def selector(state: GroupChatState) -> str: workflow = GroupChatBuilder(participants=[agent], max_rounds=1, selection_func=selector).build() - outputs: list[list[Message]] = [] + outputs: list[OrchestrationOutput] = [] async for event in workflow.run(conversation, stream=True): if event.type == "output": data = event.data - if isinstance(data, list): - outputs.append(cast(list[Message], data)) + if isinstance(data, OrchestrationOutput): + outputs.append(data) assert len(outputs) == 1 @@ -617,17 +654,17 @@ def selector(state: GroupChatState) -> str: selection_func=selector, ).build() - outputs: list[list[Message]] = [] + outputs: list[OrchestrationOutput] = [] async for event in workflow.run("test", stream=True): if event.type == "output": data = event.data - if isinstance(data, list): - outputs.append(cast(list[Message], data)) + if isinstance(data, OrchestrationOutput): + outputs.append(data) # Should have at least one output (the round limit message) assert len(outputs) >= 1 # The last message in the conversation should be about round limit - conversation = outputs[-1] + conversation = outputs[-1].messages assert len(conversation) >= 1 final_output = conversation[-1] assert "maximum number of rounds" in final_output.text.lower() @@ -650,17 +687,17 @@ def selector(state: GroupChatState) -> str: selection_func=selector, ).build() - outputs: list[list[Message]] = [] + outputs: list[OrchestrationOutput] = [] async for event in workflow.run("test", stream=True): if event.type == "output": data = event.data - if isinstance(data, list): - outputs.append(cast(list[Message], data)) + if isinstance(data, OrchestrationOutput): + outputs.append(data) # Should have at least one output (the round limit message) assert len(outputs) >= 1 # The last message in the conversation should be about round limit - conversation = outputs[-1] + conversation = outputs[-1].messages assert len(conversation) >= 1 final_output = conversation[-1] assert "maximum number of rounds" in final_output.text.lower() @@ -676,10 +713,10 @@ async def test_group_chat_checkpoint_runtime_only() -> None: wf = GroupChatBuilder(participants=[agent_a, agent_b], max_rounds=2, selection_func=selector).build() - baseline_output: list[Message] | None = None + baseline_output: OrchestrationOutput | None = None async for ev in wf.run("runtime checkpoint test", checkpoint_storage=storage, stream=True): if ev.type == "output": - baseline_output = cast(list[Message], ev.data) if isinstance(ev.data, list) else None # type: ignore + baseline_output = ev.data if isinstance(ev.data, OrchestrationOutput) else None # type: ignore if ev.type == "status" and ev.state in ( WorkflowRunState.IDLE, WorkflowRunState.IDLE_WITH_PENDING_REQUESTS, @@ -712,10 +749,10 @@ async def test_group_chat_checkpoint_runtime_overrides_buildtime() -> None: checkpoint_storage=buildtime_storage, selection_func=selector, ).build() - baseline_output: list[Message] | None = None + baseline_output: OrchestrationOutput | None = None async for ev in wf.run("override test", checkpoint_storage=runtime_storage, stream=True): if ev.type == "output": - baseline_output = cast(list[Message], ev.data) if isinstance(ev.data, list) else None # type: ignore + baseline_output = ev.data if isinstance(ev.data, OrchestrationOutput) else None # type: ignore if ev.type == "status" and ev.state in ( WorkflowRunState.IDLE, WorkflowRunState.IDLE_WITH_PENDING_REQUESTS, @@ -891,10 +928,10 @@ class DynamicManagerAgent(Agent): """Manager agent that dynamically selects from available participants.""" def __init__(self) -> None: - super().__init__(client=MockChatClient(), name="dynamic_manager", description="Dynamic manager") + super().__init__(client=cast(Any, MockChatClient()), name="dynamic_manager", description="Dynamic manager") self._call_count = 0 - async def run( + async def run( # type: ignore[override] self, messages: str | Content | Message | Sequence[str | Content | Message] | None = None, *, @@ -903,12 +940,6 @@ async def run( ) -> AgentResponse: if self._call_count == 0: self._call_count += 1 - payload = { - "terminate": False, - "reason": "Selecting alpha", - "next_speaker": "alpha", - "final_message": None, - } return AgentResponse( messages=[ Message( @@ -920,15 +951,8 @@ async def run( author_name=self.name, ) ], - value=payload, ) - payload = { - "terminate": True, - "reason": "Task complete", - "next_speaker": None, - "final_message": "dynamic manager final", - } return AgentResponse( messages=[ Message( @@ -940,7 +964,6 @@ async def run( author_name=self.name, ) ], - value=payload, ) def agent_factory() -> Agent: @@ -964,11 +987,9 @@ def agent_factory() -> Agent: assert len(outputs) == 1 # The DynamicManagerAgent terminates after second call with final_message final_messages = outputs[0].data - assert isinstance(final_messages, list) + assert isinstance(final_messages, OrchestrationOutput) assert any( - msg.text == "dynamic manager final" - for msg in cast(list[Message], final_messages) - if msg.author_name == "dynamic_manager" + msg.text == "dynamic manager final" for msg in final_messages.messages if msg.author_name == "dynamic_manager" ) diff --git a/python/packages/orchestrations/tests/test_handoff.py b/python/packages/orchestrations/tests/test_handoff.py index 43c2f9153a..2739e97f6c 100644 --- a/python/packages/orchestrations/tests/test_handoff.py +++ b/python/packages/orchestrations/tests/test_handoff.py @@ -30,7 +30,7 @@ _AutoHandoffMiddleware, # pyright: ignore[reportPrivateUsage] get_handoff_tool_name, ) -from agent_framework_orchestrations._orchestrator_helpers import clean_conversation_for_handoff +from agent_framework_orchestrations._orchestration_shared import OrchestrationOutput, filter_tool_contents class MockChatClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]): @@ -83,10 +83,10 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: contents = _build_reply_contents(self._name, self._handoff_to, self._next_call_id()) yield ChatResponseUpdate(contents=contents, role="assistant", finish_reason="stop") - def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse: + def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse[Any]: response_format = options.get("response_format") output_format_type = response_format if isinstance(response_format, type) else None - return ChatResponse.from_updates(updates, output_format_type=output_format_type) + return ChatResponse.from_updates(updates, output_format_type=output_format_type) # pyright: ignore[reportUnknownVariableType, reportReturnType] return ResponseStream(_stream(), finalizer=_finalize) @@ -692,7 +692,7 @@ def test_handoff_clone_disables_provider_side_storage() -> None: executor = workflow.executors[resolve_agent_id(triage)] assert isinstance(executor, HandoffAgentExecutor) - assert executor._agent.default_options.get("store") is False + assert executor._agent.default_options.get("store") is False # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAttributeAccessIssue] async def test_handoff_clears_stale_service_session_id_before_run() -> None: @@ -703,14 +703,14 @@ async def test_handoff_clears_stale_service_session_id_before_run() -> None: triage_executor = workflow.executors[resolve_agent_id(triage)] assert isinstance(triage_executor, HandoffAgentExecutor) - triage_executor._session.service_session_id = "resp_stale_value" + triage_executor._session.service_session_id = "resp_stale_value" # pyright: ignore[reportPrivateUsage] await _drain(workflow.run("My order is damaged", stream=True)) - assert triage_executor._session.service_session_id is None + assert triage_executor._session.service_session_id is None # pyright: ignore[reportPrivateUsage] -def test_clean_conversation_for_handoff_keeps_text_only_history() -> None: +def test_filter_tool_contents_keeps_text_only_history() -> None: """Tool-control messages must be excluded from persisted handoff history.""" function_call = Content.from_function_call( call_id="handoff-call-1", @@ -740,7 +740,7 @@ def test_clean_conversation_for_handoff_keeps_text_only_history() -> None: ), ] - cleaned = clean_conversation_for_handoff(conversation) + cleaned = filter_tool_contents(conversation) assert [message.role for message in cleaned] == ["user", "assistant"] assert [message.text for message in cleaned] == [ "My order arrived damaged.", @@ -756,7 +756,7 @@ def test_persist_missing_approved_function_results_handles_runtime_and_fallback_ call_with_runtime_result = "call-runtime-result" call_with_approval_only = "call-approval-only" - executor._full_conversation = [ + executor._full_conversation = [ # pyright: ignore[reportPrivateUsage] Message( role="assistant", contents=[ @@ -779,9 +779,9 @@ def test_persist_missing_approved_function_results_handles_runtime_and_fallback_ Message(role="user", contents=[approval_response]), ] - executor._persist_missing_approved_function_results(runtime_tool_messages=runtime_messages, response_messages=[]) + executor._persist_missing_approved_function_results(runtime_tool_messages=runtime_messages, response_messages=[]) # pyright: ignore[reportPrivateUsage] - persisted_tool_messages = [message for message in executor._full_conversation if message.role == "tool"] + persisted_tool_messages = [message for message in executor._full_conversation if message.role == "tool"] # pyright: ignore[reportPrivateUsage] assert persisted_tool_messages persisted_results = [ content @@ -827,9 +827,11 @@ async def test_autonomous_mode_yields_output_without_user_request(): assert outputs, "Autonomous mode should yield a workflow output" final_conversation = outputs[-1].data - assert isinstance(final_conversation, list) - conversation_list = cast(list[Message], final_conversation) - assert any(msg.role == "assistant" and (msg.text or "").startswith("specialist reply") for msg in conversation_list) + assert isinstance(final_conversation, OrchestrationOutput) + assert any( + msg.role == "assistant" and (msg.text or "").startswith("specialist reply") + for msg in final_conversation.messages + ) async def test_autonomous_mode_resumes_user_input_on_turn_limit(): @@ -897,9 +899,8 @@ async def async_termination(conv: list[Message]) -> bool: assert len(outputs) == 1 final_conversation = outputs[0].data - assert isinstance(final_conversation, list) - final_conv_list = cast(list[Message], final_conversation) - user_messages = [msg for msg in final_conv_list if msg.role == "user"] + assert isinstance(final_conversation, OrchestrationOutput) + user_messages = [msg for msg in final_conversation.messages if msg.role == "user"] assert len(user_messages) == 2 assert termination_call_count > 0 @@ -955,7 +956,7 @@ async def _get() -> ChatResponse: outputs = [event for event in events if event.type == "output"] assert outputs - conversation_outputs = [event for event in outputs if isinstance(event.data, list)] + conversation_outputs = [event for event in outputs if isinstance(event.data, OrchestrationOutput)] assert len(conversation_outputs) == 1 @@ -1098,22 +1099,24 @@ def test_handoff_builder_rejects_non_agent_supports_agent_run(): from agent_framework import AgentResponse, AgentSession, SupportsAgentRun class FakeAgentRun: - def __init__(self, id, name): + def __init__(self, id: str, name: str) -> None: self.id = id self.name = name self.description = "d" - async def run(self, messages=None, *, stream=False, session=None, **kwargs): + async def run( + self, messages: Any = None, *, stream: bool = False, session: Any = None, **kwargs: Any + ) -> AgentResponse: return AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])]) - def create_session(self, **kwargs): + def create_session(self, **kwargs: Any) -> AgentSession: return AgentSession() - def get_session(self, *, service_session_id, **kwargs): + def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession: return AgentSession(service_session_id=service_session_id) fake = FakeAgentRun("a", "A") assert isinstance(fake, SupportsAgentRun) with pytest.raises(TypeError, match="Participants must be Agent instances"): - HandoffBuilder().participants([fake]) + HandoffBuilder().participants([fake]) # pyright: ignore[reportArgumentType] diff --git a/python/packages/orchestrations/tests/test_magentic.py b/python/packages/orchestrations/tests/test_magentic.py index 1857a16ee4..73e4a7ca42 100644 --- a/python/packages/orchestrations/tests/test_magentic.py +++ b/python/packages/orchestrations/tests/test_magentic.py @@ -3,17 +3,19 @@ import sys from collections.abc import AsyncIterable, Awaitable, Sequence from dataclasses import dataclass -from typing import Any, ClassVar, cast +from typing import Any, ClassVar, Literal, cast, overload import pytest from agent_framework import ( AgentResponse, AgentResponseUpdate, + AgentRunInputs, AgentSession, BaseAgent, Content, Executor, Message, + ResponseStream, SupportsAgentRun, Workflow, WorkflowCheckpoint, @@ -36,6 +38,8 @@ StandardMagenticManager, ) +from agent_framework_orchestrations._orchestration_shared import OrchestrationOutput + if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover else: @@ -148,6 +152,24 @@ def __init__(self, agent_name: str, reply_text: str, **kwargs: Any) -> None: super().__init__(name=agent_name, description=f"Stub agent {agent_name}", **kwargs) self._reply_text = reply_text + @overload + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: Literal[False] = ..., + session: AgentSession | None = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... # type: ignore[override] + @overload + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: Literal[True], + session: AgentSession | None = None, + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... # type: ignore[override] def run( # type: ignore[override] self, messages: str | Content | Message | Sequence[str | Content | Message] | None = None, @@ -155,7 +177,7 @@ def run( # type: ignore[override] stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> Awaitable[AgentResponse] | AsyncIterable[AgentResponseUpdate]: + ) -> Awaitable[AgentResponse[Any]] | AsyncIterable[AgentResponseUpdate]: if stream: return self._run_stream() @@ -195,8 +217,8 @@ async def test_magentic_builder_returns_workflow_and_runs() -> None: async for event in workflow.run("compose summary", stream=True): if event.type == "output": msg = event.data - if isinstance(msg, list): - outputs.extend(cast(list[Message], msg)) + if isinstance(msg, OrchestrationOutput): + outputs.extend(msg.messages) elif event.type == "magentic_orchestrator": orchestrator_event_count += 1 @@ -250,7 +272,7 @@ async def test_magentic_workflow_plan_review_approval_to_completion(): assert isinstance(req_event.data, MagenticPlanReviewRequest) completed = False - output: list[Message] | None = None + output: OrchestrationOutput | None = None async for ev in wf.run(stream=True, responses={req_event.request_id: req_event.data.approve()}): if ev.type == "status" and ev.state == WorkflowRunState.IDLE: completed = True @@ -261,8 +283,8 @@ async def test_magentic_workflow_plan_review_approval_to_completion(): assert completed assert output is not None - assert isinstance(output, list) - assert all(isinstance(msg, Message) for msg in output) + assert isinstance(output, OrchestrationOutput) + assert all(isinstance(msg, Message) for msg in output.messages) async def test_magentic_plan_review_with_revise(): @@ -337,10 +359,10 @@ async def test_magentic_orchestrator_round_limit_produces_partial_result(): output_event = next((e for e in events if e.type == "output"), None) assert output_event is not None data = output_event.data - assert isinstance(data, list) - assert len(data) > 0 # type: ignore - assert data[-1].role == "assistant" # type: ignore - assert all(isinstance(msg, Message) for msg in data) # type: ignore + assert isinstance(data, OrchestrationOutput) + assert len(data.messages) > 0 + assert data.messages[-1].role == "assistant" + assert all(isinstance(msg, Message) for msg in data.messages) async def test_magentic_checkpoint_resume_round_trip(): @@ -409,14 +431,32 @@ async def test_magentic_checkpoint_resume_round_trip(): class StubManagerAgent(BaseAgent): """Stub agent for testing StandardMagenticManager.""" + @overload + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: Literal[False] = ..., + session: AgentSession | None = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... # type: ignore[override] + @overload def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: Literal[True], + session: AgentSession | None = None, + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... # type: ignore[override] + def run( # type: ignore[override] self, messages: str | Content | Message | Sequence[str | Content | Message] | None = None, *, stream: bool = False, session: Any = None, **kwargs: Any, - ) -> Awaitable[AgentResponse] | AsyncIterable[AgentResponseUpdate]: + ) -> Awaitable[AgentResponse[Any]] | AsyncIterable[AgentResponseUpdate]: if stream: return self._run_stream() @@ -426,11 +466,11 @@ async def _run() -> AgentResponse: return _run() async def _run_stream(self) -> AsyncIterable[AgentResponseUpdate]: - yield AgentResponseUpdate(message_deltas=[Message("assistant", ["ok"])]) + yield AgentResponseUpdate(contents=[Content.from_text(text="ok")]) async def test_standard_manager_plan_and_replan_via_complete_monkeypatch(): - mgr = StandardMagenticManager(StubManagerAgent()) + mgr = StandardMagenticManager(StubManagerAgent()) # type: ignore[arg-type] async def fake_complete_plan(messages: list[Message], **kwargs: Any) -> Message: # Return a different response depending on call order length @@ -460,7 +500,7 @@ async def fake_complete_replan(messages: list[Message], **kwargs: Any) -> Messag async def test_standard_manager_progress_ledger_success_and_error(): - mgr = StandardMagenticManager(agent=StubManagerAgent()) + mgr = StandardMagenticManager(agent=StubManagerAgent()) # type: ignore[arg-type] ctx = MagenticContext(task="task", participant_descriptions={"alice": "desc"}) # Success path: valid JSON @@ -526,7 +566,25 @@ class StubThreadAgent(BaseAgent): def __init__(self, name: str | None = None) -> None: super().__init__(name=name or "agentA") - def run(self, messages=None, *, stream: bool = False, session=None, **kwargs): # type: ignore[override] + @overload + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: Literal[False] = ..., + session: AgentSession | None = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... # type: ignore[override] + @overload + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: Literal[True], + session: AgentSession | None = None, + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... # type: ignore[override] + def run(self, messages=None, *, stream: bool = False, session=None, **kwargs: Any): # type: ignore[override] if stream: return self._run_stream() @@ -554,7 +612,25 @@ def __init__(self) -> None: super().__init__(name="agentA") self.client = StubAssistantsClient() # type name contains 'AssistantsClient' - def run(self, messages=None, *, stream: bool = False, session=None, **kwargs): # type: ignore[override] + @overload + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: Literal[False] = ..., + session: AgentSession | None = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... # type: ignore[override] + @overload + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: Literal[True], + session: AgentSession | None = None, + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... # type: ignore[override] + def run(self, messages=None, *, stream: bool = False, session=None, **kwargs: Any): # type: ignore[override] if stream: return self._run_stream() @@ -753,11 +829,11 @@ async def test_magentic_stall_and_reset_reach_limits(): assert idle_status is not None output_event = next((e for e in events if e.type == "output"), None) assert output_event is not None - assert isinstance(output_event.data, list) - assert all(isinstance(msg, Message) for msg in output_event.data) # type: ignore - assert len(output_event.data) > 0 # type: ignore - assert output_event.data[-1].text is not None # type: ignore - assert output_event.data[-1].text == "Workflow terminated due to reaching maximum reset count." # type: ignore + assert isinstance(output_event.data, OrchestrationOutput) + assert all(isinstance(msg, Message) for msg in output_event.data.messages) + assert len(output_event.data.messages) > 0 + assert output_event.data.messages[-1].text is not None + assert output_event.data.messages[-1].text == "Workflow terminated due to reaching maximum reset count." async def test_magentic_checkpoint_runtime_only() -> None: @@ -1085,14 +1161,32 @@ async def test_standard_manager_propagates_session_to_agent(): class SessionCapturingAgent(BaseAgent): """Agent that records the session passed to each run() call.""" + @overload + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: Literal[False] = ..., + session: AgentSession | None = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... # type: ignore[override] + @overload def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: Literal[True], + session: AgentSession | None = None, + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... # type: ignore[override] + def run( # type: ignore[override] self, messages: str | Content | Message | Sequence[str | Content | Message] | None = None, *, stream: bool = False, session: Any = None, **kwargs: Any, - ) -> Awaitable[AgentResponse] | AsyncIterable[AgentResponseUpdate]: + ) -> Awaitable[AgentResponse[Any]] | AsyncIterable[AgentResponseUpdate]: captured_sessions.append(session) async def _run() -> AgentResponse: @@ -1101,7 +1195,7 @@ async def _run() -> AgentResponse: return _run() agent = SessionCapturingAgent() - mgr = StandardMagenticManager(agent=agent) + mgr = StandardMagenticManager(agent=agent) # type: ignore[arg-type] ctx = MagenticContext(task="task", participant_descriptions={"a": "desc"}) await mgr.plan(ctx.clone()) @@ -1110,35 +1204,35 @@ async def _run() -> AgentResponse: assert len(captured_sessions) == 2 assert all(s is not None for s in captured_sessions), "session must be passed to agent.run()" assert captured_sessions[0] is captured_sessions[1], "same session instance must be reused across calls" - assert captured_sessions[0] is mgr._session + assert captured_sessions[0] is mgr._session # type: ignore[reportPrivateUsage] def test_standard_manager_checkpoint_preserves_session(): """Verify that checkpoint save/restore preserves the manager's session identity.""" agent = StubManagerAgent() - mgr = StandardMagenticManager(agent=agent) - original_session_id = mgr._session.session_id + mgr = StandardMagenticManager(agent=agent) # type: ignore[arg-type] + original_session_id = mgr._session.session_id # type: ignore[reportPrivateUsage] state = mgr.on_checkpoint_save() assert "agent_session" in state # Restore into a fresh manager and verify session_id is preserved - mgr2 = StandardMagenticManager(agent=agent) - assert mgr2._session.session_id != original_session_id + mgr2 = StandardMagenticManager(agent=agent) # type: ignore[arg-type] + assert mgr2._session.session_id != original_session_id # type: ignore[reportPrivateUsage] mgr2.on_checkpoint_restore(state) - assert mgr2._session.session_id == original_session_id + assert mgr2._session.session_id == original_session_id # type: ignore[reportPrivateUsage] def test_standard_manager_checkpoint_restore_empty_state(): """Verify that restoring from a state without agent_session leaves the session intact.""" agent = StubManagerAgent() - mgr = StandardMagenticManager(agent=agent) - original_session = mgr._session + mgr = StandardMagenticManager(agent=agent) # type: ignore[arg-type] + original_session = mgr._session # type: ignore[reportPrivateUsage] original_session_id = original_session.session_id mgr.on_checkpoint_restore({}) - assert mgr._session is original_session - assert mgr._session.session_id == original_session_id + assert mgr._session is original_session # type: ignore[reportPrivateUsage] + assert mgr._session.session_id == original_session_id # type: ignore[reportPrivateUsage] # endregion diff --git a/python/packages/orchestrations/tests/test_orchestration_request_info.py b/python/packages/orchestrations/tests/test_orchestration_request_info.py index 7d0acbc945..d97691e6a7 100644 --- a/python/packages/orchestrations/tests/test_orchestration_request_info.py +++ b/python/packages/orchestrations/tests/test_orchestration_request_info.py @@ -2,16 +2,19 @@ """Unit tests for orchestration request info support.""" -from collections.abc import AsyncIterable -from typing import Any +from collections.abc import AsyncIterable, Awaitable +from typing import Any, Literal, overload from unittest.mock import AsyncMock, MagicMock import pytest from agent_framework import ( AgentResponse, AgentResponseUpdate, + AgentRunInputs, AgentSession, + Content, Message, + ResponseStream, SupportsAgentRun, ) from agent_framework._workflows._agent_executor import AgentExecutorRequest, AgentExecutorResponse @@ -177,47 +180,61 @@ async def test_handle_request_info_response_approval(self): class _TestAgent: """Simple test agent implementation.""" - def __init__(self, id: str, name: str | None = None, description: str | None = None): - self._id = id - self._name = name - self._description = description - - @property - def id(self) -> str: - return self._id - - @property - def name(self) -> str | None: - return self._name - - @property - def display_name(self) -> str: - return self._name or self._id + id: str + name: str | None + description: str | None - @property - def description(self) -> str | None: - return self._description + def __init__(self, id: str, name: str | None = None, description: str | None = None): + self.id = id + self.name = name + self.description = description - async def run( + @overload + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: Literal[False] = ..., + session: AgentSession | None = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... + @overload + def run( self, - messages: str | Message | list[str] | list[Message] | None = None, + messages: AgentRunInputs | None = None, + *, + stream: Literal[True], + session: AgentSession | None = None, + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = None, *, stream: bool = False, - thread: AgentSession | None = None, + session: AgentSession | None = None, **kwargs: Any, - ) -> AgentResponse | AsyncIterable[AgentResponseUpdate]: + ) -> Awaitable[AgentResponse[Any]] | AsyncIterable[AgentResponseUpdate]: """Dummy run method.""" if stream: return self._run_stream_impl() - return AgentResponse(messages=[Message(role="assistant", text="Test response")]) + + async def _run() -> AgentResponse[Any]: + return AgentResponse(messages=[Message(role="assistant", text="Test response")]) + + return _run() async def _run_stream_impl(self) -> AsyncIterable[AgentResponseUpdate]: - yield AgentResponseUpdate(messages=[Message(role="assistant", text="Test response stream")]) + yield AgentResponseUpdate(contents=[Content.from_text(text="Test response stream")]) def create_session(self, **kwargs: Any) -> AgentSession: """Creates a new conversation session for the agent.""" return AgentSession(**kwargs) + def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession: + """Gets or creates a session for a service-managed session ID.""" + return AgentSession(**kwargs) + class TestAgentApprovalExecutor: """Tests for AgentApprovalExecutor.""" diff --git a/python/packages/orchestrations/tests/test_sequential.py b/python/packages/orchestrations/tests/test_sequential.py index 67bcc1bb9e..89485a65a7 100644 --- a/python/packages/orchestrations/tests/test_sequential.py +++ b/python/packages/orchestrations/tests/test_sequential.py @@ -1,18 +1,20 @@ # Copyright (c) Microsoft. All rights reserved. from collections.abc import AsyncIterable, Awaitable -from typing import Any +from typing import Any, Literal, overload import pytest from agent_framework import ( AgentExecutorResponse, AgentResponse, AgentResponseUpdate, + AgentRunInputs, AgentSession, BaseAgent, Content, Executor, Message, + ResponseStream, TypeCompatibilityError, WorkflowContext, WorkflowRunState, @@ -21,20 +23,42 @@ from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage from agent_framework.orchestrations import SequentialBuilder +from agent_framework_orchestrations._orchestration_shared import OrchestrationOutput + class _EchoAgent(BaseAgent): """Simple agent that appends a single assistant message with its name.""" - def run( # type: ignore[override] + @overload + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... + + @overload + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + + def run( self, - messages: str | Message | list[str] | list[Message] | None = None, + messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> Awaitable[AgentResponse] | AsyncIterable[AgentResponseUpdate]: + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: if stream: - return self._run_stream() + return ResponseStream(self._run_stream(), finalizer=AgentResponse.from_updates) async def _run() -> AgentResponse: return AgentResponse(messages=[Message("assistant", [f"{self.name} reply"])]) @@ -84,7 +108,7 @@ async def test_sequential_agents_append_to_context() -> None: wf = SequentialBuilder(participants=[a1, a2]).build() completed = False - output: list[Message] | None = None + output: OrchestrationOutput | None = None async for ev in wf.run("hello sequential", stream=True): if ev.type == "status" and ev.state == WorkflowRunState.IDLE: completed = True @@ -95,8 +119,8 @@ async def test_sequential_agents_append_to_context() -> None: assert completed assert output is not None - assert isinstance(output, list) - msgs: list[Message] = output + assert isinstance(output, OrchestrationOutput) + msgs: list[Message] = output.messages assert len(msgs) == 3 assert msgs[0].role == "user" and "hello sequential" in msgs[0].text assert msgs[1].role == "assistant" and (msgs[1].author_name == "A1" or True) @@ -112,7 +136,7 @@ async def test_sequential_with_custom_executor_summary() -> None: wf = SequentialBuilder(participants=[a1, summarizer]).build() completed = False - output: list[Message] | None = None + output: OrchestrationOutput | None = None async for ev in wf.run("topic X", stream=True): if ev.type == "status" and ev.state == WorkflowRunState.IDLE: completed = True @@ -123,7 +147,7 @@ async def test_sequential_with_custom_executor_summary() -> None: assert completed assert output is not None - msgs: list[Message] = output + msgs: list[Message] = output.messages # Expect: [user, A1 reply, summary] assert len(msgs) == 3 assert msgs[0].role == "user" @@ -137,7 +161,7 @@ async def test_sequential_checkpoint_resume_round_trip() -> None: initial_agents = (_EchoAgent(id="agent1", name="A1"), _EchoAgent(id="agent2", name="A2")) wf = SequentialBuilder(participants=list(initial_agents), checkpoint_storage=storage).build() - baseline_output: list[Message] | None = None + baseline_output: OrchestrationOutput | None = None async for ev in wf.run("checkpoint sequential", stream=True): if ev.type == "output": baseline_output = ev.data # type: ignore[assignment] @@ -154,7 +178,7 @@ async def test_sequential_checkpoint_resume_round_trip() -> None: resumed_agents = (_EchoAgent(id="agent1", name="A1"), _EchoAgent(id="agent2", name="A2")) wf_resume = SequentialBuilder(participants=list(resumed_agents), checkpoint_storage=storage).build() - resumed_output: list[Message] | None = None + resumed_output: OrchestrationOutput | None = None async for ev in wf_resume.run(checkpoint_id=resume_checkpoint.checkpoint_id, stream=True): if ev.type == "output": resumed_output = ev.data # type: ignore[assignment] @@ -165,8 +189,8 @@ async def test_sequential_checkpoint_resume_round_trip() -> None: break assert resumed_output is not None - assert [m.role for m in resumed_output] == [m.role for m in baseline_output] - assert [m.text for m in resumed_output] == [m.text for m in baseline_output] + assert [m.role for m in resumed_output.messages] == [m.role for m in baseline_output.messages] + assert [m.text for m in resumed_output.messages] == [m.text for m in baseline_output.messages] async def test_sequential_checkpoint_runtime_only() -> None: @@ -176,7 +200,7 @@ async def test_sequential_checkpoint_runtime_only() -> None: agents = (_EchoAgent(id="agent1", name="A1"), _EchoAgent(id="agent2", name="A2")) wf = SequentialBuilder(participants=list(agents)).build() - baseline_output: list[Message] | None = None + baseline_output: OrchestrationOutput | None = None async for ev in wf.run("runtime checkpoint test", checkpoint_storage=storage, stream=True): if ev.type == "output": baseline_output = ev.data # type: ignore[assignment] @@ -193,7 +217,7 @@ async def test_sequential_checkpoint_runtime_only() -> None: resumed_agents = (_EchoAgent(id="agent1", name="A1"), _EchoAgent(id="agent2", name="A2")) wf_resume = SequentialBuilder(participants=list(resumed_agents)).build() - resumed_output: list[Message] | None = None + resumed_output: OrchestrationOutput | None = None async for ev in wf_resume.run( checkpoint_id=resume_checkpoint.checkpoint_id, checkpoint_storage=storage, stream=True ): @@ -206,8 +230,8 @@ async def test_sequential_checkpoint_runtime_only() -> None: break assert resumed_output is not None - assert [m.role for m in resumed_output] == [m.role for m in baseline_output] - assert [m.text for m in resumed_output] == [m.text for m in baseline_output] + assert [m.role for m in resumed_output.messages] == [m.role for m in baseline_output.messages] + assert [m.text for m in resumed_output.messages] == [m.text for m in baseline_output.messages] async def test_sequential_checkpoint_runtime_overrides_buildtime() -> None: @@ -223,7 +247,7 @@ async def test_sequential_checkpoint_runtime_overrides_buildtime() -> None: agents = (_EchoAgent(id="agent1", name="A1"), _EchoAgent(id="agent2", name="A2")) wf = SequentialBuilder(participants=list(agents), checkpoint_storage=buildtime_storage).build() - baseline_output: list[Message] | None = None + baseline_output: OrchestrationOutput | None = None async for ev in wf.run("override test", checkpoint_storage=runtime_storage, stream=True): if ev.type == "output": baseline_output = ev.data # type: ignore[assignment]