Skip to content
Open
21 changes: 20 additions & 1 deletion src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,12 @@ async def push_notification_callback() -> None:

except Exception:
logger.exception('Agent execution failed')
# If the consumer fails, we must cancel the producer to prevent it from hanging
# on queue operations (e.g., waiting for the queue to drain).
producer_task.cancel()
# Force the queue to close immediately, discarding any pending events.
# This ensures that any producers waiting on the queue are unblocked.
await queue.close(immediate=True)
raise
finally:
if interrupted_or_non_blocking:
Expand Down Expand Up @@ -392,6 +398,12 @@ async def on_message_send_stream(
bg_task.set_name(f'background_consume:{task_id}')
self._track_background_task(bg_task)
raise
except Exception:
# If the consumer fails (e.g. database error), we must clean up.
logger.exception('Agent execution failed during streaming')
producer_task.cancel()
await queue.close(immediate=True)
raise
finally:
cleanup_task = asyncio.create_task(
self._cleanup_producer(producer_task, task_id)
Expand Down Expand Up @@ -435,7 +447,14 @@ async def _cleanup_producer(
task_id: str,
) -> None:
"""Cleans up the agent execution task and queue manager entry."""
await producer_task
try:
await producer_task
except asyncio.CancelledError:
logger.debug(
'Producer task %s was cancelled during cleanup', task_id
)
except Exception:
logger.exception('Producer task %s failed during cleanup', task_id)
await self._queue_manager.close(task_id)
async with self._running_agents_lock:
self._running_agents.pop(task_id, None)
Expand Down
88 changes: 88 additions & 0 deletions tests/server/request_handlers/test_default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2644,3 +2644,91 @@ async def test_on_message_send_stream_task_id_provided_but_task_not_found():
f'Task {task_id} was specified but does not exist'
in exc_info.value.error.message
)


@pytest.mark.asyncio
async def test_on_message_send_stream_consumer_error_cancels_producer_and_closes_queue():
    """A failing stream consumer must cancel the producer and force-close the queue.

    The patched ``ResultAggregator`` raises ``ValueError`` as soon as the
    handler starts consuming. The test then checks that the error reaches the
    caller, that ``cancel`` was invoked on the producer task, and that the
    event queue was last awaited as ``close(immediate=True)``.
    """
    task_id = 'error_cleanup_task'
    context_id = 'error_cleanup_ctx'

    # Collaborators handed to the handler under test.
    task_store = AsyncMock(spec=TaskStore)
    queue_manager = AsyncMock(spec=QueueManager)
    agent_executor = AsyncMock(spec=AgentExecutor)
    context_builder = AsyncMock(spec=RequestContextBuilder)

    # The builder returns a request context carrying our ids.
    context_builder.build.return_value = MagicMock(
        spec=RequestContext,
        task_id=task_id,
        context_id=context_id,
    )

    # The queue manager hands out a mock queue whose close() we inspect later.
    event_queue = AsyncMock(spec=EventQueue)
    queue_manager.create_or_tap.return_value = event_queue

    request_handler = DefaultRequestHandler(
        agent_executor=agent_executor,
        task_store=task_store,
        queue_manager=queue_manager,
        request_context_builder=context_builder,
    )

    # No task_id on the message on purpose: supplying one would trigger the
    # "Task ... was specified but does not exist" error instead.
    message = Message(
        role=Role.user,
        message_id='msg_error_cleanup',
        parts=[],
    )
    params = MessageSendParams(message=message)

    async def raise_error_gen(_consumer):
        # Async generator that fails on first iteration (consumer failure).
        raise ValueError('Consumer failed!')
        yield  # never reached; only here to make this an async generator

    failing_aggregator = MagicMock(spec=ResultAggregator)
    failing_aggregator.consume_and_emit.side_effect = raise_error_gen

    # Holder for the producer task seen by the registration spy.
    spy_state = {'producer': None}
    real_register = request_handler._register_producer

    async def spy_register_producer(tid, task):
        # Remember the task and wrap cancel() so the call can be asserted,
        # then defer to the real registration logic.
        spy_state['producer'] = task
        task.cancel = MagicMock(wraps=task.cancel)
        await real_register(tid, task)

    handler_module = 'a2a.server.request_handlers.default_request_handler'

    # Act: draining the stream must surface the consumer's error.
    with (
        patch(
            f'{handler_module}.ResultAggregator',
            return_value=failing_aggregator,
        ),
        patch(
            f'{handler_module}.TaskManager.get_task',
            return_value=None,
        ),
        patch.object(
            request_handler,
            '_register_producer',
            side_effect=spy_register_producer,
        ),
        pytest.raises(ValueError, match='Consumer failed!'),
    ):
        async for _ in request_handler.on_message_send_stream(
            params, create_server_call_context()
        ):
            pass

    producer = spy_state['producer']
    assert producer is not None
    # The producer task must have been asked to cancel.
    producer.cancel.assert_called()

    # The queue must have been closed in immediate mode.
    event_queue.close.assert_awaited_with(immediate=True)
1 change: 0 additions & 1 deletion tests/server/request_handlers/test_jsonrpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,6 @@ async def streaming_coro():

self.assertIsInstance(response.root, JSONRPCErrorResponse)
assert response.root.error == UnsupportedOperationError() # type: ignore
mock_agent_executor.execute.assert_called_once()

@patch(
'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
Expand Down
Loading