diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative.Mcp/DefaultMcpToolHandler.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative.Mcp/DefaultMcpToolHandler.cs index 107f4f0260..0abc02122d 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative.Mcp/DefaultMcpToolHandler.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative.Mcp/DefaultMcpToolHandler.cs @@ -61,10 +61,18 @@ public async Task InvokeToolAsync( ? null : arguments as IReadOnlyDictionary ?? new Dictionary(arguments); - CallToolResult result = await client.CallToolAsync( - toolName, - readOnlyArguments, - cancellationToken: cancellationToken).ConfigureAwait(false); + CallToolResult result; + try + { + result = await client.CallToolAsync( + toolName, + readOnlyArguments, + cancellationToken: cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) when (TryClassifyToolInvocationFailure(ex.Message, out string? failureCode)) + { + throw new InvalidOperationException($"[{failureCode}] {ex.Message}", ex); + } // Map MCP content blocks to MEAI AIContent types PopulateResultContent(resultContent, result); @@ -183,6 +191,35 @@ private static string ComputeHeadersHash(IDictionary? headers) return hashCode.ToString(CultureInfo.InvariantCulture); } + internal static bool TryClassifyToolInvocationFailure(string? message, out string? failureCode) + { + if (string.IsNullOrWhiteSpace(message)) + { + failureCode = null; + return false; + } + + string normalized = message.ToLowerInvariant(); + if (normalized.Contains("tool not found", StringComparison.Ordinal) || + normalized.Contains("unknown tool", StringComparison.Ordinal) || + normalized.Contains("no tool named", StringComparison.Ordinal)) + { + failureCode = "mcp_tool_missing"; + return true; + } + + if (normalized.Contains("invalid params", StringComparison.Ordinal) || + normalized.Contains("schema", StringComparison.Ordinal) || + normalized.Contains("validation", StringComparison.Ordinal)) + { + failureCode = "mcp_tool_schema_mismatch"; + return true; + } + + failureCode = null; + return false; + } + private static void PopulateResultContent(McpServerToolResultContent resultContent, CallToolResult result) { // Ensure Output list is initialized diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.Mcp.UnitTests/DefaultMcpToolHandlerTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.Mcp.UnitTests/DefaultMcpToolHandlerTests.cs index abfa95cc36..f558005c81 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.Mcp.UnitTests/DefaultMcpToolHandlerTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.Mcp.UnitTests/DefaultMcpToolHandlerTests.cs @@ -488,5 +488,28 @@ public void ConvertContentBlock_AudioContentBlock_WithNullMimeType_ShouldDefault dataContent.MediaType.Should().Be("audio/*"); } + [Theory] + [InlineData("Tool not found on remote server", "mcp_tool_missing")] + [InlineData("Invalid params: schema changed", "mcp_tool_schema_mismatch")] + [InlineData("Request failed validation", "mcp_tool_schema_mismatch")] + public void TryClassifyToolInvocationFailure_WithKnownSchemaOrToolMessages_ReturnsStableCode( + string message, + string expectedCode) + { + bool classified = DefaultMcpToolHandler.TryClassifyToolInvocationFailure(message, out string? code); + + classified.Should().BeTrue(); + code.Should().Be(expectedCode); + } + + [Fact] + public void TryClassifyToolInvocationFailure_WithUnrelatedMessage_ReturnsFalse() + { + bool classified = DefaultMcpToolHandler.TryClassifyToolInvocationFailure("Socket closed unexpectedly", out string? code); + + classified.Should().BeFalse(); + code.Should().BeNull(); + } + #endregion } diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index 5901e34dd9..e185394477 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -286,6 +286,17 @@ def _parse_content_from_mcp( return return_types +def _classify_mcp_tool_failure(message: str) -> str | None: + lowered = message.lower() + + if "tool not found" in lowered or "unknown tool" in lowered or "no tool named" in lowered: + return "mcp_tool_missing" + if "invalid params" in lowered or "schema" in lowered or "validation" in lowered: + return "mcp_tool_schema_mismatch" + + return None + + def _prepare_content_for_mcp( content: Content, ) -> types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource | types.ResourceLink | None: @@ -637,6 +648,9 @@ async def _connect_on_owner(self, *, reset: bool = False) -> None: self.session = None self.is_connected = False self._exit_stack = AsyncExitStack() + self._functions = [] + self._tools_loaded = False + self._prompts_loaded = False if not self.session: try: transport = await self._exit_stack.enter_async_context(self.get_mcp_client()) @@ -1054,6 +1068,18 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]: inner_exception=cl_ex, ) from cl_ex except McpError as mcp_exc: + failure_code = _classify_mcp_tool_failure(mcp_exc.error.message) + if failure_code is not None: + try: + await self.connect(reset=True) + except Exception: + logger.debug( + "Failed to refresh MCP tool definitions after classified tool failure.", exc_info=True + ) + raise ToolExecutionException( + f"[{failure_code}] {mcp_exc.error.message}", + inner_exception=mcp_exc, + ) from mcp_exc raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc except Exception as ex: raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index b29ec1a794..d1f99e08a2 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -23,6 +23,7 @@ ) from agent_framework._mcp import ( MCPTool, + _classify_mcp_tool_failure, _get_input_model_from_mcp_prompt, _normalize_mcp_name, _parse_content_from_mcp, @@ -53,6 +54,12 @@ def test_normalize_mcp_name(): assert _normalize_mcp_name("name/with\\slashes") == "name-with-slashes" +def test_classify_mcp_tool_failure(): + assert _classify_mcp_tool_failure("Tool not found on remote server") == "mcp_tool_missing" + assert _classify_mcp_tool_failure("Invalid params for schema validation") == "mcp_tool_schema_mismatch" + assert _classify_mcp_tool_failure("transport closed") is None + + def test_mcp_transport_subclasses_accept_tool_name_prefix() -> None: assert MCPStdioTool(name="stdio", command="python", tool_name_prefix="stdio").tool_name_prefix == "stdio" assert ( @@ -1032,6 +1039,49 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: await func.invoke(param="test_value") +async def test_local_mcp_server_schema_drift_error_is_classified_and_refreshes(): + """Schema drift should fail closed with a stable marker and trigger a tool refresh.""" + + class TestServer(MCPTool): + async def connect(self): + self.session = Mock(spec=ClientSession) + self.session.list_tools = AsyncMock( + return_value=types.ListToolsResult( + tools=[ + types.Tool( + name="test_tool", + description="Test tool", + inputSchema={ + "type": "object", + "properties": {"param": {"type": "string"}}, + "required": ["param"], + }, + ) + ] + ) + ) + self.session.call_tool = AsyncMock( + side_effect=McpError(types.ErrorData(code=-32602, message="Invalid params: schema changed")) + ) + + def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: + return None + + server = TestServer(name="test_server") + async with server: + await server.load_tools() + func = server.functions[0] + + with ( + patch.object(server, "connect", new_callable=AsyncMock) as mock_connect, + pytest.raises(ToolExecutionException, match=r"\[mcp_tool_schema_mismatch\]") as exc_info, + ): + await func.invoke(param="test_value") + + mock_connect.assert_awaited_once_with(reset=True) + assert "schema changed" in str(exc_info.value) + + async def test_mcp_tool_call_tool_raises_on_is_error(): """Test that call_tool raises ToolExecutionException when MCP returns isError=True."""