Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@
# Orchestration helpers
"AgentRequestInfoResponse",
"OrchestrationState",
"clean_conversation_for_handoff",
"create_completion_message",
# Group Chat
"AgentBasedGroupChatOrchestrator",
"AgentOrchestrationOutput",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@
StandardMagenticManager,
)
from ._orchestration_request_info import AgentRequestInfoResponse
from ._orchestration_state import OrchestrationState
from ._orchestrator_helpers import clean_conversation_for_handoff, create_completion_message
from ._orchestration_shared import OrchestrationState
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OrchestrationOutput is the primary type SDK consumers need for isinstance checks and type hints on workflow results, but it is never imported here. Add it to the imports.

Suggested change
from ._orchestration_shared import OrchestrationState
from ._orchestration_shared import OrchestrationOutput, OrchestrationState

from ._sequential import SequentialBuilder

__all__ = [
Expand Down Expand Up @@ -105,6 +104,4 @@
"StandardMagenticManager",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OrchestrationOutput is not included in all. Add it so users can import it from the public package path rather than the private _orchestration_shared module.

Suggested change
"StandardMagenticManager",
"OrchestrationOutput",
"StandardMagenticManager",
"TerminationCondition",
"__version__",

"TerminationCondition",
"__version__",
"clean_conversation_for_handoff",
"create_completion_message",
]
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing_extensions import Never

from ._orchestration_request_info import AgentApprovalExecutor
from ._orchestration_shared import OrchestrationOutput

if sys.version_info >= (3, 12):
from typing import override # type: ignore # pragma: no cover
Expand Down Expand Up @@ -351,7 +352,7 @@ async def _check_termination(self) -> bool:
result = await result
return result

async def _check_terminate_and_yield(self, ctx: WorkflowContext[Never, list[Message]]) -> bool:
async def _check_terminate_and_yield(self, ctx: WorkflowContext[Never, OrchestrationOutput]) -> bool:
"""Check termination conditions and yield completion if met.

Args:
Expand All @@ -363,7 +364,7 @@ async def _check_terminate_and_yield(self, ctx: WorkflowContext[Never, list[Mess
terminate = await self._check_termination()
if terminate:
self._append_messages([self._create_completion_message(self.TERMINATION_CONDITION_MET_MESSAGE)])
await ctx.yield_output(self._full_conversation)
await ctx.yield_output(OrchestrationOutput(messages=self._full_conversation))
return True

return False
Expand Down Expand Up @@ -490,7 +491,7 @@ def _check_round_limit(self) -> bool:

return False

async def _check_round_limit_and_yield(self, ctx: WorkflowContext[Never, list[Message]]) -> bool:
async def _check_round_limit_and_yield(self, ctx: WorkflowContext[Never, OrchestrationOutput]) -> bool:
"""Check round limit and yield completion if reached.

Args:
Expand All @@ -502,7 +503,7 @@ async def _check_round_limit_and_yield(self, ctx: WorkflowContext[Never, list[Me
reach_max_rounds = self._check_round_limit()
if reach_max_rounds:
self._append_messages([self._create_completion_message(self.MAX_ROUNDS_MET_MESSAGE)])
await ctx.yield_output(self._full_conversation)
await ctx.yield_output(OrchestrationOutput(messages=self._full_conversation))
return True

return False
Expand All @@ -521,7 +522,7 @@ async def on_checkpoint_save(self) -> dict[str, Any]:
Returns:
Serialized state dict
"""
from ._orchestration_state import OrchestrationState
from ._orchestration_shared import OrchestrationState

state = OrchestrationState(
conversation=list(self._full_conversation),
Expand Down Expand Up @@ -551,7 +552,7 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
Args:
state: Serialized state dict
"""
from ._orchestration_state import OrchestrationState
from ._orchestration_shared import OrchestrationState

orch_state = OrchestrationState.from_dict(state)
self._full_conversation = list(orch_state.conversation)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# Copyright (c) Microsoft. All rights reserved.

import asyncio
import inspect
import logging
from collections.abc import Callable, Sequence
from typing import Any
from typing import Any, Awaitable

from agent_framework import Message, SupportsAgentRun
from agent_framework._workflows._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse
Expand All @@ -18,6 +17,7 @@
from typing_extensions import Never

from ._orchestration_request_info import AgentApprovalExecutor
from ._orchestration_shared import OrchestrationOutput

logger = logging.getLogger(__name__)

Expand All @@ -32,13 +32,11 @@
- Participants can be provided as SupportsAgentRun or Executor instances via `participants=[...]`.
- A custom aggregator can be provided as:
- an Executor instance (it should handle list[AgentExecutorResponse],
yield output), or
yield OrchestrationOutput), or
- a callback function with signature:
def cb(results: list[AgentExecutorResponse]) -> Any | None
def cb(results: list[AgentExecutorResponse], ctx: WorkflowContext) -> Any | None
The callback is wrapped in _CallbackAggregator.
If the callback returns a non-None value, _CallbackAggregator yields that as output.
If it returns None, the callback may have already yielded an output via ctx, so no further action is taken.
def cb(results: list[AgentExecutorResponse]) -> list[Message]
The callback is wrapped in _CallbackAggregator, which wraps the returned
list[Message] in an OrchestrationOutput and yields it.
"""


Expand Down Expand Up @@ -82,7 +80,9 @@ class _AggregateAgentConversations(Executor):
"""

@handler
async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, list[Message]]) -> None:
async def aggregate(
self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, OrchestrationOutput]
) -> None:
if not results:
logger.error("Concurrent aggregator received empty results list")
raise ValueError("Aggregation failed: no results provided")
Expand Down Expand Up @@ -137,47 +137,43 @@ def _is_role(msg: Any, role: str) -> bool:
logger.warning("No user prompt found in any conversation; emitting assistants only")
output.extend(assistant_replies)

await ctx.yield_output(output)
await ctx.yield_output(OrchestrationOutput(messages=output))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stale docstring: still says 'Sync callbacks are executed via asyncio.to_thread to avoid blocking the event loop' but the implementation no longer does this. Update the docstring to match the actual behavior, or (better) restore the asyncio.to_thread wrapping.



class _CallbackAggregator(Executor):
"""Wraps a Python callback as an aggregator.

Accepts either an async or sync callback with one of the signatures:
- (results: list[AgentExecutorResponse]) -> Any | None
- (results: list[AgentExecutorResponse], ctx: WorkflowContext[Any]) -> Any | None
Accepts either an async or sync callback with the following signature:
- (results: list[AgentExecutorResponse]) -> list[Message]

Comment on lines 145 to +148
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: sync callbacks are now called directly on the event loop thread instead of via asyncio.to_thread, blocking all concurrent async tasks. The asyncio import was removed entirely. Restore the asyncio.to_thread path for non-awaitable callbacks.

Suggested change
Accepts either an async or sync callback with one of the signatures:
- (results: list[AgentExecutorResponse]) -> Any | None
- (results: list[AgentExecutorResponse], ctx: WorkflowContext[Any]) -> Any | None
Accepts either an async or sync callback with the following signature:
- (results: list[AgentExecutorResponse]) -> list[Message]
if inspect.iscoroutinefunction(self._callback):
ret = await self._callback(results)
else:
import asyncio
ret = await asyncio.to_thread(self._callback, results)
await ctx.yield_output(OrchestrationOutput(messages=list(ret)))

The returned list[Message] is automatically wrapped in an OrchestrationOutput.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the callback returns None at runtime (e.g., a function with no return statement), list(ret) raises TypeError. The old code had an explicit None guard. Add validation to produce a clear error message.

Suggested change
if ret is None:
raise TypeError("Aggregator callback must return list[Message], got None")
await ctx.yield_output(OrchestrationOutput(messages=list(ret)))

Notes:
- Async callbacks are awaited directly.
- Sync callbacks are executed via asyncio.to_thread to avoid blocking the event loop.
- If the callback returns a non-None value, it is yielded as an output.
- If the callback returns a non-None value, it is wrapped in OrchestrationOutput and yielded.
"""

def __init__(self, callback: Callable[..., Any], id: str | None = None) -> None:
def __init__(
self,
callback: Callable[[list[AgentExecutorResponse]], Awaitable[list[Message]] | list[Message]],
id: str | None = None,
) -> None:
derived_id = getattr(callback, "__name__", "") or ""
if not derived_id or derived_id == "<lambda>":
derived_id = f"{type(self).__name__}_unnamed"
super().__init__(id or derived_id)
self._callback = callback
self._param_count = len(inspect.signature(callback).parameters)

@handler
async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, Any]) -> None:
# Call according to provided signature, always non-blocking for sync callbacks
if self._param_count >= 2:
if inspect.iscoroutinefunction(self._callback):
ret = await self._callback(results, ctx) # type: ignore[misc]
else:
ret = await asyncio.to_thread(self._callback, results, ctx)
else:
if inspect.iscoroutinefunction(self._callback):
ret = await self._callback(results) # type: ignore[misc]
else:
ret = await asyncio.to_thread(self._callback, results)
async def aggregate(
self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, OrchestrationOutput]
) -> None:
ret = self._callback(results)
if inspect.isawaitable(ret):
ret = await ret

# If the callback returned a value, finalize the workflow with it
if ret is not None:
await ctx.yield_output(ret)
await ctx.yield_output(OrchestrationOutput(messages=list(ret)))


class ConcurrentBuilder:
Expand All @@ -193,7 +189,7 @@ class ConcurrentBuilder:

from agent_framework_orchestrations import ConcurrentBuilder

# Minimal: use default aggregator (returns list[Message])
# Minimal: use default aggregator (returns OrchestrationOutput)
workflow = ConcurrentBuilder(participants=[agent1, agent2, agent3]).build()


Expand Down Expand Up @@ -265,17 +261,14 @@ def _set_participants(self, participants: Sequence[SupportsAgentRun | Executor])

def with_aggregator(
self,
aggregator: Executor
| Callable[[list[AgentExecutorResponse]], Any]
| Callable[[list[AgentExecutorResponse], WorkflowContext[Never, Any]], Any],
aggregator: Executor | Callable[[list[AgentExecutorResponse]], Awaitable[list[Message]] | list[Message]],
) -> "ConcurrentBuilder":
r"""Override the default aggregator with an executor or a callback.

- Executor: must handle `list[AgentExecutorResponse]` and yield output using `ctx.yield_output(...)`
- Callback: sync or async callable with one of the signatures:
`(results: list[AgentExecutorResponse]) -> Any | None` or
`(results: list[AgentExecutorResponse], ctx: WorkflowContext) -> Any | None`.
If the callback returns a non-None value, it becomes the workflow's output.
- Callback: sync or async callable with the signature:
`(results: list[AgentExecutorResponse]) -> list[Message]`.
The returned list[Message] is automatically wrapped in an OrchestrationOutput.

Args:
aggregator: Executor instance, or callback function
Expand All @@ -287,23 +280,15 @@ def with_aggregator(
class CustomAggregator(Executor):
@handler
async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext) -> None:
await ctx.yield_output(" | ".join(r.agent_response.messages[-1].text for r in results))
await ctx.yield_output(OrchestrationOutput(messages=[...]))


wf = ConcurrentBuilder(participants=[a1, a2, a3]).with_aggregator(CustomAggregator()).build()


# Callback-based aggregator (string result)
async def summarize(results: list[AgentExecutorResponse]) -> str:
return " | ".join(r.agent_response.messages[-1].text for r in results)


wf = ConcurrentBuilder(participants=[a1, a2, a3]).with_aggregator(summarize).build()


# Callback-based aggregator (yield result)
async def summarize(results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None:
await ctx.yield_output(" | ".join(r.agent_response.messages[-1].text for r in results))
# Callback-based aggregator (returns list[Message], wrapped in OrchestrationOutput)
async def summarize(results: list[AgentExecutorResponse]) -> list[Message]:
return [r.agent_response.messages[-1] for r in results]


wf = ConcurrentBuilder(participants=[a1, a2, a3]).with_aggregator(summarize).build()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
TerminationCondition,
)
from ._orchestration_request_info import AgentApprovalExecutor
from ._orchestrator_helpers import clean_conversation_for_handoff
from ._orchestration_shared import OrchestrationOutput, filter_tool_contents

if sys.version_info >= (3, 12):
from typing import override # type: ignore # pragma: no cover
Expand Down Expand Up @@ -169,7 +169,7 @@ async def _handle_messages(
"""Initialize orchestrator state and start the conversation loop."""
self._append_messages(messages)
# Termination condition will also be applied to the input messages
if await self._check_terminate_and_yield(cast(WorkflowContext[Never, list[Message]], ctx)):
if await self._check_terminate_and_yield(cast(WorkflowContext[Never, OrchestrationOutput], ctx)):
return

next_speaker = await self._get_next_speaker()
Expand All @@ -195,12 +195,12 @@ async def _handle_response(
"""Handle a participant response."""
messages = self._process_participant_response(response)
# Remove tool-related content to prevent API errors from empty messages
messages = clean_conversation_for_handoff(messages)
messages = filter_tool_contents(messages)
self._append_messages(messages)

if await self._check_terminate_and_yield(cast(WorkflowContext[Never, list[Message]], ctx)):
if await self._check_terminate_and_yield(cast(WorkflowContext[Never, OrchestrationOutput], ctx)):
return
if await self._check_round_limit_and_yield(cast(WorkflowContext[Never, list[Message]], ctx)):
if await self._check_round_limit_and_yield(cast(WorkflowContext[Never, OrchestrationOutput], ctx)):
return

next_speaker = await self._get_next_speaker()
Expand Down Expand Up @@ -332,13 +332,13 @@ async def _handle_messages(
"""Initialize orchestrator state and start the conversation loop."""
self._append_messages(messages)
# Termination condition will also be applied to the input messages
if await self._check_terminate_and_yield(cast(WorkflowContext[Never, list[Message]], ctx)):
if await self._check_terminate_and_yield(cast(WorkflowContext[Never, OrchestrationOutput], ctx)):
return

agent_orchestration_output = await self._invoke_agent()
if await self._check_agent_terminate_and_yield(
agent_orchestration_output,
cast(WorkflowContext[Never, list[Message]], ctx),
cast(WorkflowContext[Never, OrchestrationOutput], ctx),
):
return

Expand All @@ -364,17 +364,17 @@ async def _handle_response(
"""Handle a participant response."""
messages = self._process_participant_response(response)
# Remove tool-related content to prevent API errors from empty messages
messages = clean_conversation_for_handoff(messages)
messages = filter_tool_contents(messages)
self._append_messages(messages)
if await self._check_terminate_and_yield(cast(WorkflowContext[Never, list[Message]], ctx)):
if await self._check_terminate_and_yield(cast(WorkflowContext[Never, OrchestrationOutput], ctx)):
return
if await self._check_round_limit_and_yield(cast(WorkflowContext[Never, list[Message]], ctx)):
if await self._check_round_limit_and_yield(cast(WorkflowContext[Never, OrchestrationOutput], ctx)):
return

agent_orchestration_output = await self._invoke_agent()
if await self._check_agent_terminate_and_yield(
agent_orchestration_output,
cast(WorkflowContext[Never, list[Message]], ctx),
cast(WorkflowContext[Never, OrchestrationOutput], ctx),
):
return

Expand Down Expand Up @@ -522,7 +522,7 @@ async def _invoke_agent_helper(conversation: list[Message]) -> AgentOrchestratio
async def _check_agent_terminate_and_yield(
self,
agent_orchestration_output: AgentOrchestrationOutput,
ctx: WorkflowContext[Never, list[Message]],
ctx: WorkflowContext[Never, OrchestrationOutput],
) -> bool:
"""Check if the agent requested termination and yield completion if so.

Expand All @@ -537,7 +537,7 @@ async def _check_agent_terminate_and_yield(
agent_orchestration_output.final_message or "The conversation has been terminated by the agent."
)
self._append_messages([self._create_completion_message(final_message)])
await ctx.yield_output(self._full_conversation)
await ctx.yield_output(OrchestrationOutput(messages=self._full_conversation))
return True

return False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from agent_framework._workflows._workflow_context import WorkflowContext

from ._base_group_chat_orchestrator import TerminationCondition
from ._orchestrator_helpers import clean_conversation_for_handoff
from ._orchestration_shared import OrchestrationOutput, filter_tool_contents

if sys.version_info >= (3, 12):
from typing import override # type: ignore # pragma: no cover
Expand Down Expand Up @@ -428,7 +428,7 @@ def _handoff_tool() -> None:
async def _run_agent_and_emit(self, ctx: WorkflowContext[Any, Any]) -> None:
"""Override to support handoff."""
incoming_messages = list(self._cache)
cleaned_incoming_messages = clean_conversation_for_handoff(incoming_messages)
cleaned_incoming_messages = filter_tool_contents(incoming_messages)
runtime_tool_messages = [
message
for message in incoming_messages
Expand Down Expand Up @@ -492,7 +492,7 @@ async def _run_agent_and_emit(self, ctx: WorkflowContext[Any, Any]) -> None:

# Remove function call related content from the agent response for broadcast.
# This prevents replaying stale tool artifacts to other agents.
cleaned_response = clean_conversation_for_handoff(response.messages)
cleaned_response = filter_tool_contents(response.messages)

# For internal tracking, preserve the full response (including function_calls)
# in _full_conversation so that Azure OpenAI can match function_calls with
Expand Down Expand Up @@ -624,7 +624,7 @@ def _is_handoff_requested(self, response: AgentResponse) -> str | None:

return None

async def _check_terminate_and_yield(self, ctx: WorkflowContext[Any, Any]) -> bool:
async def _check_terminate_and_yield(self, ctx: WorkflowContext[Any, OrchestrationOutput]) -> bool:
"""Check termination conditions and yield completion if met.

Args:
Expand All @@ -641,7 +641,7 @@ async def _check_terminate_and_yield(self, ctx: WorkflowContext[Any, Any]) -> bo
terminated = await terminated

if terminated:
await ctx.yield_output(self._full_conversation)
await ctx.yield_output(OrchestrationOutput(messages=self._full_conversation))
return True

return False
Expand Down
Loading
Loading