From a7bf127b2e0077daad7384e16edb3b3c0e95e3a9 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Feb 2026 11:15:26 +0100 Subject: [PATCH 01/28] PR2: Wire context provider pipeline and update all internal consumers - Replace AgentThread with AgentSession across all packages - Replace ContextProvider with BaseContextProvider across all packages - Replace context_provider param with context_providers (Sequence) - Replace thread= with session= in run() signatures - Replace get_new_thread() with create_session() - Add get_session(service_session_id) to agent interface - DurableAgentThread -> DurableAgentSession - Remove _notify_thread_of_new_messages from WorkflowAgent - Wire before_run/after_run context provider pipeline in RawAgent - Auto-inject InMemoryHistoryProvider when no providers configured --- .../a2a/agent_framework_a2a/_agent.py | 10 +- .../ag-ui/agent_framework_ag_ui/_run.py | 8 +- .../__init__.py | 7 +- .../_context_provider.py | 16 +- .../test_aisearch_new_context_provider.py | 20 +- .../_agent_provider.py | 26 +- .../agent_framework_azure_ai/_chat_client.py | 12 +- .../agent_framework_azure_ai/_client.py | 12 +- .../_project_provider.py | 26 +- .../agent_framework_azurefunctions/_app.py | 4 +- .../_orchestration.py | 6 +- .../tests/test_orchestration.py | 34 +- .../claude/agent_framework_claude/_agent.py | 61 +-- .../agent_framework_copilotstudio/_agent.py | 44 +- .../copilotstudio/tests/test_copilot_agent.py | 22 +- .../packages/core/agent_framework/__init__.py | 1 + .../packages/core/agent_framework/_agents.py | 455 +++++++----------- .../packages/core/agent_framework/_clients.py | 12 +- .../core/agent_framework/_middleware.py | 26 +- .../core/agent_framework/_workflows/_agent.py | 74 +-- .../_workflows/_agent_executor.py | 62 ++- .../core/agent_framework/observability.py | 16 +- .../openai/_assistant_provider.py | 26 +- .../agent_framework_devui/_conversations.py | 20 +- .../devui/agent_framework_devui/_executor.py | 14 +- 
.../devui/tests/devui/test_conversations.py | 30 +- python/packages/durabletask/AGENTS.md | 2 +- .../agent_framework_durabletask/__init__.py | 4 +- .../agent_framework_durabletask/_executors.py | 38 +- .../agent_framework_durabletask/_models.py | 73 ++- .../agent_framework_durabletask/_shim.py | 16 +- .../test_01_dt_single_agent.py | 28 +- .../test_02_dt_multi_agent.py | 14 +- .../test_03_dt_single_agent_streaming.py | 36 +- .../tests/test_agent_session_id.py | 208 ++++---- .../packages/durabletask/tests/test_client.py | 20 +- .../durabletask/tests/test_executors.py | 52 +- .../tests/test_orchestration_context.py | 20 +- .../packages/durabletask/tests/test_shim.py | 48 +- .../agent_framework_github_copilot/_agent.py | 58 +-- .../mem0/agent_framework_mem0/__init__.py | 6 +- .../agent_framework_mem0/_context_provider.py | 17 +- .../tests/test_mem0_new_context_provider.py | 64 +-- .../_group_chat.py | 25 +- .../_handoff.py | 6 +- .../_magentic.py | 4 +- .../agent_framework_purview/_middleware.py | 6 +- .../redis/agent_framework_redis/__init__.py | 12 +- .../_context_provider.py | 17 +- .../_history_provider.py | 17 +- .../redis/tests/test_new_providers.py | 74 +-- 51 files changed, 866 insertions(+), 1043 deletions(-) diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index a938128f10..fa53d8d675 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -31,7 +31,7 @@ from agent_framework import ( AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, BaseAgent, Content, ContinuationToken, @@ -211,7 +211,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: Literal[False] = ..., - thread: AgentThread | None = None, + session: AgentSession | None = None, continuation_token: A2AContinuationToken | None = None, background: bool = False, **kwargs: Any, @@ -223,7 +223,7 @@ def run( messages: str | 
Message | Sequence[str | Message] | None = None, *, stream: Literal[True], - thread: AgentThread | None = None, + session: AgentSession | None = None, continuation_token: A2AContinuationToken | None = None, background: bool = False, **kwargs: Any, @@ -234,7 +234,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, continuation_token: A2AContinuationToken | None = None, background: bool = False, **kwargs: Any, @@ -246,7 +246,7 @@ def run( Keyword Args: stream: Whether to stream the response. Defaults to False. - thread: The conversation thread associated with the message(s). + session: The conversation session associated with the message(s). continuation_token: Optional token to resume a long-running task instead of starting a new one. background: When True, in-progress task updates surface continuation diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_run.py index c376120a5a..69eeba84ff 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_run.py @@ -27,7 +27,7 @@ ToolCallStartEvent, ) from agent_framework import ( - AgentThread, + AgentSession, Content, Message, SupportsAgentRun, @@ -812,9 +812,9 @@ async def run_agent_stream( # Create thread (with service thread support) if config.use_service_thread: supplied_thread_id = input_data.get("thread_id") or input_data.get("threadId") - thread = AgentThread(service_thread_id=supplied_thread_id) + thread = AgentSession(service_session_id=supplied_thread_id) else: - thread = AgentThread() + thread = AgentSession() # Inject metadata for AG-UI orchestration (Feature #2: Azure-safe truncation) base_metadata: dict[str, Any] = { @@ -826,7 +826,7 @@ async def run_agent_stream( thread.metadata = _build_safe_metadata(base_metadata) # type: ignore[attr-defined] # Build run kwargs 
(Feature #6: Azure store flag when metadata present) - run_kwargs: dict[str, Any] = {"thread": thread} + run_kwargs: dict[str, Any] = {"session": thread} if tools: run_kwargs["tools"] = tools # Filter out AG-UI internal metadata keys before passing to chat client diff --git a/python/packages/azure-ai-search/agent_framework_azure_ai_search/__init__.py b/python/packages/azure-ai-search/agent_framework_azure_ai_search/__init__.py index 4509c46d3e..7308f427c5 100644 --- a/python/packages/azure-ai-search/agent_framework_azure_ai_search/__init__.py +++ b/python/packages/azure-ai-search/agent_framework_azure_ai_search/__init__.py @@ -2,8 +2,8 @@ import importlib.metadata -from ._context_provider import _AzureAISearchContextProvider -from ._search_provider import AzureAISearchContextProvider, AzureAISearchSettings +from ._context_provider import AzureAISearchContextProvider +from ._search_provider import AzureAISearchSettings try: __version__ = importlib.metadata.version(__name__) @@ -11,8 +11,7 @@ __version__ = "0.0.0" # Fallback for development mode __all__ = [ - "AzureAISearchContextProvider", "AzureAISearchSettings", - "_AzureAISearchContextProvider", + "AzureAISearchContextProvider", "__version__", ] diff --git a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py index 091695165d..127359372b 100644 --- a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py +++ b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py @@ -2,9 +2,8 @@ """New-pattern Azure AI Search context provider using BaseContextProvider. -This module provides ``_AzureAISearchContextProvider``, a side-by-side implementation of -:class:`AzureAISearchContextProvider` built on the new :class:`BaseContextProvider` hooks -pattern. It will replace the existing class in PR2. 
+This module provides ``AzureAISearchContextProvider``, built on the new +:class:`BaseContextProvider` hooks pattern. """ from __future__ import annotations @@ -111,16 +110,11 @@ _DEFAULT_AGENTIC_MESSAGE_HISTORY_COUNT = 10 -class _AzureAISearchContextProvider(BaseContextProvider): +class AzureAISearchContextProvider(BaseContextProvider): """Azure AI Search context provider using the new BaseContextProvider hooks pattern. Retrieves relevant context from Azure AI Search using semantic or agentic search - modes. This is the new-pattern equivalent of :class:`AzureAISearchContextProvider`. - - Note: - This class uses a temporary ``_`` prefix to coexist with the existing - :class:`AzureAISearchContextProvider`. It will replace the existing class - in PR2. + modes. """ _DEFAULT_SEARCH_CONTEXT_PROMPT: ClassVar[str] = "Use the following context to answer the question:" @@ -621,4 +615,4 @@ def _extract_document_text(self, doc: dict[str, Any], doc_id: str | None = None) return text -__all__ = ["_AzureAISearchContextProvider"] +__all__ = ["AzureAISearchContextProvider"] diff --git a/python/packages/azure-ai-search/tests/test_aisearch_new_context_provider.py b/python/packages/azure-ai-search/tests/test_aisearch_new_context_provider.py index e9af893273..8c18617e6e 100644 --- a/python/packages/azure-ai-search/tests/test_aisearch_new_context_provider.py +++ b/python/packages/azure-ai-search/tests/test_aisearch_new_context_provider.py @@ -9,7 +9,7 @@ from agent_framework._sessions import AgentSession, SessionContext from agent_framework.exceptions import ServiceInitializationError -from agent_framework_azure_ai_search._context_provider import _AzureAISearchContextProvider +from agent_framework_azure_ai_search._context_provider import AzureAISearchContextProvider # -- Helpers ------------------------------------------------------------------- @@ -56,7 +56,7 @@ async def _search(**kwargs): return client -def _make_provider(**overrides) -> _AzureAISearchContextProvider: +def 
_make_provider(**overrides) -> AzureAISearchContextProvider: """Create a semantic-mode provider with mocked internals (skips auto-discovery).""" defaults = { "source_id": "aisearch", @@ -65,7 +65,7 @@ def _make_provider(**overrides) -> _AzureAISearchContextProvider: "api_key": "test-key", } defaults.update(overrides) - provider = _AzureAISearchContextProvider(**defaults) + provider = AzureAISearchContextProvider(**defaults) provider._auto_discovered_vector_field = True # skip auto-discovery return provider @@ -89,7 +89,7 @@ def test_source_id_set(self) -> None: def test_missing_endpoint_raises(self) -> None: with patch.dict(os.environ, {}, clear=True), pytest.raises(ServiceInitializationError, match="endpoint"): - _AzureAISearchContextProvider( + AzureAISearchContextProvider( source_id="s", endpoint=None, index_name="idx", @@ -98,7 +98,7 @@ def test_missing_endpoint_raises(self) -> None: def test_missing_index_name_semantic_raises(self) -> None: with pytest.raises(ServiceInitializationError, match="index name"): - _AzureAISearchContextProvider( + AzureAISearchContextProvider( source_id="s", endpoint="https://test.search.windows.net", index_name=None, @@ -112,7 +112,7 @@ def test_env_variable_fallback(self) -> None: "AZURE_SEARCH_API_KEY": "env-key", } with patch.dict(os.environ, env, clear=False): - provider = _AzureAISearchContextProvider(source_id="env-test") + provider = AzureAISearchContextProvider(source_id="env-test") assert provider.endpoint == "https://env.search.windows.net" assert provider.index_name == "env-index" @@ -125,7 +125,7 @@ class TestInitAgenticValidation: def test_both_index_and_kb_raises(self) -> None: with pytest.raises(ServiceInitializationError, match="not both"): - _AzureAISearchContextProvider( + AzureAISearchContextProvider( source_id="s", endpoint="https://test.search.windows.net", index_name="idx", @@ -138,7 +138,7 @@ def test_both_index_and_kb_raises(self) -> None: def test_neither_index_nor_kb_raises(self) -> None: with 
pytest.raises(ServiceInitializationError, match="provide either"): - _AzureAISearchContextProvider( + AzureAISearchContextProvider( source_id="s", endpoint="https://test.search.windows.net", api_key="key", @@ -147,7 +147,7 @@ def test_neither_index_nor_kb_raises(self) -> None: def test_missing_model_deployment_name_raises(self) -> None: with pytest.raises(ServiceInitializationError, match="model_deployment_name"): - _AzureAISearchContextProvider( + AzureAISearchContextProvider( source_id="s", endpoint="https://test.search.windows.net", index_name="idx", @@ -158,7 +158,7 @@ def test_missing_model_deployment_name_raises(self) -> None: def test_vector_field_without_embedding_raises(self) -> None: with pytest.raises(ValueError, match="embedding_function"): - _AzureAISearchContextProvider( + AzureAISearchContextProvider( source_id="s", endpoint="https://test.search.windows.net", index_name="idx", diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_agent_provider.py b/python/packages/azure-ai/agent_framework_azure_ai/_agent_provider.py index 5ea9983c50..f5c5201531 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_agent_provider.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_agent_provider.py @@ -9,7 +9,7 @@ from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, Agent, - ContextProvider, + BaseContextProvider, FunctionTool, MiddlewareTypes, normalize_tools, @@ -176,7 +176,7 @@ async def create_agent( | None = None, default_options: OptionsCoT | None = None, middleware: Sequence[MiddlewareTypes] | None = None, - context_provider: ContextProvider | None = None, + context_providers: Sequence[BaseContextProvider] | None = None, ) -> Agent[OptionsCoT]: """Create a new agent on the Azure AI service and return a Agent. @@ -195,7 +195,7 @@ async def create_agent( default_options: A TypedDict containing default chat options for the agent. These options are applied to every run unless overridden. 
middleware: List of middleware to intercept agent and function invocations. - context_provider: Context provider to include during agent invocation. + context_providers: Context providers to include during agent invocation. Returns: Agent: A Agent instance configured with the created agent. @@ -259,7 +259,7 @@ async def create_agent( normalized_tools, default_options=default_options, middleware=middleware, - context_provider=context_provider, + context_providers=context_providers, ) async def get_agent( @@ -273,7 +273,7 @@ async def get_agent( | None = None, default_options: OptionsCoT | None = None, middleware: Sequence[MiddlewareTypes] | None = None, - context_provider: ContextProvider | None = None, + context_providers: Sequence[BaseContextProvider] | None = None, ) -> Agent[OptionsCoT]: """Retrieve an existing agent from the service and return a Agent. @@ -289,7 +289,7 @@ async def get_agent( default_options: A TypedDict containing default chat options for the agent. These options are applied to every run unless overridden. middleware: List of middleware to intercept agent and function invocations. - context_provider: Context provider to include during agent invocation. + context_providers: Context providers to include during agent invocation. Returns: Agent: A Agent instance configured with the retrieved agent. @@ -316,7 +316,7 @@ async def get_agent( normalized_tools, default_options=default_options, middleware=middleware, - context_provider=context_provider, + context_providers=context_providers, ) def as_agent( @@ -329,7 +329,7 @@ def as_agent( | None = None, default_options: OptionsCoT | None = None, middleware: Sequence[MiddlewareTypes] | None = None, - context_provider: ContextProvider | None = None, + context_providers: Sequence[BaseContextProvider] | None = None, ) -> Agent[OptionsCoT]: """Wrap an existing Agent SDK object as a Agent without making HTTP calls. 
@@ -343,7 +343,7 @@ def as_agent( default_options: A TypedDict containing default chat options for the agent. These options are applied to every run unless overridden. middleware: List of middleware to intercept agent and function invocations. - context_provider: Context provider to include during agent invocation. + context_providers: Context providers to include during agent invocation. Returns: Agent: A Agent instance configured with the agent. @@ -373,7 +373,7 @@ def as_agent( normalized_tools, default_options=default_options, middleware=middleware, - context_provider=context_provider, + context_providers=context_providers, ) def _to_chat_agent_from_agent( @@ -382,7 +382,7 @@ def _to_chat_agent_from_agent( provided_tools: Sequence[FunctionTool | MutableMapping[str, Any]] | None = None, default_options: OptionsCoT | None = None, middleware: Sequence[MiddlewareTypes] | None = None, - context_provider: ContextProvider | None = None, + context_providers: Sequence[BaseContextProvider] | None = None, ) -> Agent[OptionsCoT]: """Create a Agent from an Agent SDK object. @@ -392,7 +392,7 @@ def _to_chat_agent_from_agent( default_options: A TypedDict containing default chat options for the agent. These options are applied to every run unless overridden. middleware: List of middleware to intercept agent and function invocations. - context_provider: Context provider to include during agent invocation. + context_providers: Context providers to include during agent invocation. 
""" # Create the underlying client client = AzureAIAgentClient( @@ -416,7 +416,7 @@ def _to_chat_agent_from_agent( tools=merged_tools, default_options=default_options, # type: ignore[arg-type] middleware=middleware, - context_provider=context_provider, + context_providers=context_providers, ) def _merge_tools( diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py index a898117a92..3b3d8b53d2 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py @@ -14,15 +14,14 @@ AGENT_FRAMEWORK_USER_AGENT, Agent, Annotation, + BaseContextProvider, BaseChatClient, ChatAndFunctionMiddlewareTypes, - ChatMessageStoreProtocol, ChatMiddlewareLayer, ChatOptions, ChatResponse, ChatResponseUpdate, Content, - ContextProvider, FunctionInvocationConfiguration, FunctionInvocationLayer, FunctionTool, @@ -1434,8 +1433,7 @@ def as_agent( | Sequence[FunctionTool | Callable[..., Any] | MutableMapping[str, Any]] | None = None, default_options: AzureAIAgentOptionsT | Mapping[str, Any] | None = None, - chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None, - context_provider: ContextProvider | None = None, + context_providers: Sequence[BaseContextProvider] | None = None, middleware: Sequence[MiddlewareTypes] | None = None, **kwargs: Any, ) -> Agent[AzureAIAgentOptionsT]: @@ -1455,8 +1453,7 @@ def as_agent( instructions: Optional instructions for the agent. tools: The tools to use for the request. default_options: A TypedDict containing chat options. - chat_message_store_factory: Factory function to create an instance of ChatMessageStoreProtocol. - context_provider: Context providers to include during agent invocation. + context_providers: Context providers to include during agent invocation. middleware: List of middleware to intercept agent and function invocations. 
kwargs: Any additional keyword arguments. @@ -1470,8 +1467,7 @@ def as_agent( instructions=instructions, tools=tools, default_options=default_options, - chat_message_store_factory=chat_message_store_factory, - context_provider=context_provider, + context_providers=context_providers, middleware=middleware, **kwargs, ) diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index a5881c4c2a..79a30b0d81 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -9,10 +9,9 @@ from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, Agent, + BaseContextProvider, ChatAndFunctionMiddlewareTypes, - ChatMessageStoreProtocol, ChatMiddlewareLayer, - ContextProvider, FunctionInvocationConfiguration, FunctionInvocationLayer, FunctionTool, @@ -808,8 +807,7 @@ def as_agent( | Sequence[FunctionTool | Callable[..., Any] | MutableMapping[str, Any]] | None = None, default_options: AzureAIClientOptionsT | Mapping[str, Any] | None = None, - chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None, - context_provider: ContextProvider | None = None, + context_providers: Sequence[BaseContextProvider] | None = None, middleware: Sequence[MiddlewareTypes] | None = None, **kwargs: Any, ) -> Agent[AzureAIClientOptionsT]: @@ -829,8 +827,7 @@ def as_agent( instructions: Optional instructions for the agent. tools: The tools to use for the request. default_options: A TypedDict containing chat options. - chat_message_store_factory: Factory function to create an instance of ChatMessageStoreProtocol. - context_provider: Context providers to include during agent invocation. + context_providers: Context providers to include during agent invocation. middleware: List of middleware to intercept agent and function invocations. kwargs: Any additional keyword arguments. 
@@ -844,8 +841,7 @@ def as_agent( instructions=instructions, tools=tools, default_options=default_options, - chat_message_store_factory=chat_message_store_factory, - context_provider=context_provider, + context_providers=context_providers, middleware=middleware, **kwargs, ) diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py b/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py index cdf5cad5cf..f2faffdb99 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py @@ -9,7 +9,7 @@ from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, Agent, - ContextProvider, + BaseContextProvider, FunctionTool, MiddlewareTypes, get_logger, @@ -168,7 +168,7 @@ async def create_agent( | None = None, default_options: OptionsCoT | None = None, middleware: Sequence[MiddlewareTypes] | None = None, - context_provider: ContextProvider | None = None, + context_providers: Sequence[BaseContextProvider] | None = None, ) -> Agent[OptionsCoT]: """Create a new agent on the Azure AI service and return a local Agent wrapper. @@ -182,7 +182,7 @@ async def create_agent( default_options: A TypedDict containing default chat options for the agent. These options are applied to every run unless overridden. middleware: List of middleware to intercept agent and function invocations. - context_provider: Context provider to include during agent invocation. + context_providers: Context providers to include during agent invocation. Returns: Agent: A Agent instance configured with the created agent. 
@@ -255,7 +255,7 @@ async def create_agent( normalized_tools, default_options=default_options, middleware=middleware, - context_provider=context_provider, + context_providers=context_providers, ) async def get_agent( @@ -270,7 +270,7 @@ async def get_agent( | None = None, default_options: OptionsCoT | None = None, middleware: Sequence[MiddlewareTypes] | None = None, - context_provider: ContextProvider | None = None, + context_providers: Sequence[BaseContextProvider] | None = None, ) -> Agent[OptionsCoT]: """Retrieve an existing agent from the Azure AI service and return a local Agent wrapper. @@ -284,7 +284,7 @@ async def get_agent( default_options: A TypedDict containing default chat options for the agent. These options are applied to every run unless overridden. middleware: List of middleware to intercept agent and function invocations. - context_provider: Context provider to include during agent invocation. + context_providers: Context providers to include during agent invocation. Returns: Agent: A Agent instance configured with the retrieved agent. @@ -317,7 +317,7 @@ async def get_agent( normalize_tools(tools), default_options=default_options, middleware=middleware, - context_provider=context_provider, + context_providers=context_providers, ) def as_agent( @@ -330,7 +330,7 @@ def as_agent( | None = None, default_options: OptionsCoT | None = None, middleware: Sequence[MiddlewareTypes] | None = None, - context_provider: ContextProvider | None = None, + context_providers: Sequence[BaseContextProvider] | None = None, ) -> Agent[OptionsCoT]: """Wrap an SDK agent version object into a Agent without making HTTP calls. @@ -342,7 +342,7 @@ def as_agent( default_options: A TypedDict containing default chat options for the agent. These options are applied to every run unless overridden. middleware: List of middleware to intercept agent and function invocations. - context_provider: Context provider to include during agent invocation. 
+ context_providers: Context providers to include during agent invocation. Returns: Agent: A Agent instance configured with the agent version. @@ -361,7 +361,7 @@ def as_agent( normalize_tools(tools), default_options=default_options, middleware=middleware, - context_provider=context_provider, + context_providers=context_providers, ) def _to_chat_agent_from_details( @@ -370,7 +370,7 @@ def _to_chat_agent_from_details( provided_tools: Sequence[FunctionTool | MutableMapping[str, Any]] | None = None, default_options: OptionsCoT | None = None, middleware: Sequence[MiddlewareTypes] | None = None, - context_provider: ContextProvider | None = None, + context_providers: Sequence[BaseContextProvider] | None = None, ) -> Agent[OptionsCoT]: """Create a Agent from an AgentVersionDetails. @@ -381,7 +381,7 @@ def _to_chat_agent_from_details( default_options: A TypedDict containing default chat options for the agent. These options are applied to every run unless overridden. middleware: List of middleware to intercept agent and function invocations. - context_provider: Context provider to include during agent invocation. + context_providers: Context providers to include during agent invocation. 
""" if not isinstance(details.definition, PromptAgentDefinition): raise ValueError("Agent definition must be PromptAgentDefinition to get a Agent.") @@ -409,7 +409,7 @@ def _to_chat_agent_from_details( tools=merged_tools, default_options=default_options, # type: ignore[arg-type] middleware=middleware, - context_provider=context_provider, + context_providers=context_providers, ) def _merge_tools( diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 724b95015b..886e0d8588 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -135,8 +135,8 @@ class AgentFunctionApp(DFAppBase): @app.orchestration_trigger(context_name="context") def my_orchestration(context): writer = app.get_agent(context, "WeatherAgent") - thread = writer.get_new_thread() - forecast_task = writer.run("What's the forecast?", thread=thread) + session = writer.create_session() + forecast_task = writer.run("What's the forecast?", session=session) forecast = yield forecast_task return forecast diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py index 4e55fe1819..5875f9119b 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, TypeAlias import azure.durable_functions as df -from agent_framework import AgentThread, get_logger +from agent_framework import AgentSession, get_logger from agent_framework_durabletask import ( DurableAgentExecutor, RunRequest, @@ -178,11 +178,11 @@ def run_durable_agent( self, agent_name: str, run_request: RunRequest, - thread: AgentThread | None = None, + session: 
AgentSession | None = None, ) -> AgentTask: # Resolve session - session_id = self._create_session_id(agent_name, thread) + session_id = self._create_session_id(agent_name, session) entity_id = df.EntityId( name=session_id.entity_name, diff --git a/python/packages/azurefunctions/tests/test_orchestration.py b/python/packages/azurefunctions/tests/test_orchestration.py index e778875887..d1be7d9a77 100644 --- a/python/packages/azurefunctions/tests/test_orchestration.py +++ b/python/packages/azurefunctions/tests/test_orchestration.py @@ -214,10 +214,10 @@ def test_fire_and_forget_calls_signal_entity(self, executor_with_uuid: tuple[Any context.call_entity = Mock(return_value=_create_entity_task()) agent = DurableAIAgent(executor, "TestAgent") - thread = agent.get_new_thread() + session = agent.create_session() # Run with wait_for_response=False - result = agent.run("Test message", thread=thread, options={"wait_for_response": False}) + result = agent.run("Test message", session=session, options={"wait_for_response": False}) # Verify signal_entity was called and call_entity was not assert context.signal_entity.call_count == 1 @@ -232,9 +232,9 @@ def test_fire_and_forget_returns_completed_task(self, executor_with_uuid: tuple[ context.signal_entity = Mock() agent = DurableAIAgent(executor, "TestAgent") - thread = agent.get_new_thread() + session = agent.create_session() - result = agent.run("Test message", thread=thread, options={"wait_for_response": False}) + result = agent.run("Test message", session=session, options={"wait_for_response": False}) # Task should be immediately complete assert isinstance(result, AgentTask) @@ -246,9 +246,9 @@ def test_fire_and_forget_returns_acceptance_response(self, executor_with_uuid: t context.signal_entity = Mock() agent = DurableAIAgent(executor, "TestAgent") - thread = agent.get_new_thread() + session = agent.create_session() - result = agent.run("Test message", thread=thread, options={"wait_for_response": False}) + result = 
agent.run("Test message", session=session, options={"wait_for_response": False}) # Get the result response = result.result @@ -267,9 +267,9 @@ def test_blocking_mode_still_works(self, executor_with_uuid: tuple[Any, Mock, st context.call_entity = Mock(return_value=_create_entity_task()) agent = DurableAIAgent(executor, "TestAgent") - thread = agent.get_new_thread() + session = agent.create_session() - result = agent.run("Test message", thread=thread, options={"wait_for_response": True}) + result = agent.run("Test message", session=session, options={"wait_for_response": True}) # Verify call_entity was called and signal_entity was not assert context.call_entity.call_count == 1 @@ -298,15 +298,15 @@ def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dic # Create agent directly with executor (not via app.get_agent) agent = DurableAIAgent(executor, "WriterAgent") - # Create thread - thread = agent.get_new_thread() + # Create session + session = agent.create_session() # First call - returns AgentTask - task1 = agent.run("Write something", thread=thread) + task1 = agent.run("Write something", session=session) assert isinstance(task1, AgentTask) # Second call - returns AgentTask - task2 = agent.run("Improve: something", thread=thread) + task2 = agent.run("Improve: something", session=session) assert isinstance(task2, AgentTask) # Verify both calls used the same entity (same session key) @@ -315,7 +315,7 @@ def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dic # EntityId format is @dafx-writeragent@ expected_entity_id = f"@dafx-writeragent@{uuid_hexes[0]}" assert entity_calls[0]["entity_id"] == expected_entity_id - # generate_unique_id called 3 times: thread + 2 correlation IDs + # generate_unique_id called 3 times: session + 2 correlation IDs assert executor.generate_unique_id.call_count == 3 def test_multiple_agents_in_orchestration(self, executor_with_multiple_uuids: tuple[Any, Mock, list[str]]) -> None: @@ -334,12 
+334,12 @@ def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dic writer = DurableAIAgent(executor, "WriterAgent") editor = DurableAIAgent(executor, "EditorAgent") - writer_thread = writer.get_new_thread() - editor_thread = editor.get_new_thread() + writer_session = writer.create_session() + editor_session = editor.create_session() # Call both agents - returns AgentTasks - writer_task = writer.run("Write", thread=writer_thread) - editor_task = editor.run("Edit", thread=editor_thread) + writer_task = writer.run("Write", session=writer_session) + editor_task = editor.run("Edit", session=editor_session) assert isinstance(writer_task, AgentTask) assert isinstance(editor_task, AgentTask) diff --git a/python/packages/claude/agent_framework_claude/_agent.py b/python/packages/claude/agent_framework_claude/_agent.py index 3e900b8e27..10ef5cbf45 100644 --- a/python/packages/claude/agent_framework_claude/_agent.py +++ b/python/packages/claude/agent_framework_claude/_agent.py @@ -12,12 +12,13 @@ AgentMiddlewareTypes, AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, BaseAgent, + BaseContextProvider, Content, - ContextProvider, FunctionTool, Message, + ResponseStream, get_logger, normalize_messages, ) @@ -184,9 +185,9 @@ class ClaudeAgent(BaseAgent, Generic[OptionsT]): .. 
code-block:: python async with ClaudeAgent() as agent: - thread = agent.get_new_thread() - await agent.run("Remember my name is Alice", thread=thread) - response = await agent.run("What's my name?", thread=thread) + session = agent.create_session() + await agent.run("Remember my name is Alice", session=session) + response = await agent.run("What's my name?", session=session) # Claude will remember "Alice" from the same session With Agent Framework tools: @@ -214,7 +215,7 @@ def __init__( id: str | None = None, name: str | None = None, description: str | None = None, - context_provider: ContextProvider | None = None, + context_providers: Sequence[BaseContextProvider] | None = None, middleware: Sequence[AgentMiddlewareTypes] | None = None, tools: FunctionTool | Callable[..., Any] @@ -237,7 +238,7 @@ def __init__( id: Unique identifier for the agent. name: Name of the agent. description: Description of the agent. - context_provider: Context provider for the agent. + context_providers: Context providers for the agent. middleware: List of middleware. tools: Tools for the agent. Can be: - Strings for built-in tools (e.g., "Read", "Write", "Bash", "Glob") @@ -250,7 +251,7 @@ def __init__( id=id, name=name, description=description, - context_provider=context_provider, + context_providers=context_providers, middleware=middleware, ) @@ -559,7 +560,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: Literal[True], - thread: AgentThread | None = None, + session: AgentSession | None = None, options: OptionsT | MutableMapping[str, Any] | None = None, **kwargs: Any, ) -> AsyncIterable[AgentResponseUpdate]: ... @@ -570,7 +571,7 @@ async def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: Literal[False] = ..., - thread: AgentThread | None = None, + session: AgentSession | None = None, options: OptionsT | MutableMapping[str, Any] | None = None, **kwargs: Any, ) -> AgentResponse[Any]: ... 
@@ -580,7 +581,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, options: OptionsT | MutableMapping[str, Any] | None = None, **kwargs: Any, ) -> AsyncIterable[AgentResponseUpdate] | Awaitable[AgentResponse[Any]]: @@ -592,46 +593,36 @@ def run( Keyword Args: stream: If True, returns an async iterable of updates. If False (default), returns an awaitable AgentResponse. - thread: The conversation thread. If thread has service_thread_id set, + session: The conversation session. If session has service_session_id set, the agent will resume that session. options: Runtime options (model, permission_mode can be changed per-request). kwargs: Additional keyword arguments. Returns: - When stream=True: An AsyncIterable[AgentResponseUpdate] for streaming updates. + When stream=True: An ResponseStream for streaming updates. When stream=False: An Awaitable[AgentResponse] with the complete response. 
""" - if stream: - return self._run_streaming(messages, thread=thread, options=options, **kwargs) - return self._run_non_streaming(messages, thread=thread, options=options, **kwargs) - - async def _run_non_streaming( - self, - messages: str | Message | Sequence[str | Message] | None = None, - *, - thread: AgentThread | None = None, - options: OptionsT | MutableMapping[str, Any] | None = None, - **kwargs: Any, - ) -> AgentResponse[Any]: - """Internal non-streaming implementation.""" - thread = thread or self.get_new_thread() - return await AgentResponse.from_update_generator( - self._run_streaming(messages, thread=thread, options=options, **kwargs) + response = ResponseStream( + self._get_stream(messages, session=session, options=options, **kwargs), + finalizer=AgentResponse.from_updates, ) + if stream: + return response + return response.get_final_response() - async def _run_streaming( + async def _get_stream( self, messages: str | Message | Sequence[str | Message] | None = None, *, - thread: AgentThread | None = None, + session: AgentSession | None = None, options: OptionsT | MutableMapping[str, Any] | None = None, **kwargs: Any, ) -> AsyncIterable[AgentResponseUpdate]: """Internal streaming implementation.""" - thread = thread or self.get_new_thread() + session = session or self.create_session() # Ensure we're connected to the right session - await self._ensure_session(thread.service_thread_id) + await self._ensure_session(session.service_session_id) if not self._client: raise ServiceException("Claude SDK client not initialized.") @@ -696,6 +687,6 @@ async def _run_streaming( raise ServiceException(f"Claude API error: {error_msg}") session_id = message.session_id - # Update thread with session ID + # Update session with session ID if session_id: - thread.service_thread_id = session_id + session.service_session_id = session_id diff --git a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py 
b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py index c0e395b210..a3729d325d 100644 --- a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py +++ b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py @@ -9,10 +9,10 @@ AgentMiddlewareTypes, AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, BaseAgent, + BaseContextProvider, Content, - ContextProvider, Message, ResponseStream, normalize_messages, @@ -59,7 +59,7 @@ def __init__( id: str | None = None, name: str | None = None, description: str | None = None, - context_provider: ContextProvider | None = None, + context_providers: Sequence[BaseContextProvider] | None = None, middleware: list[AgentMiddlewareTypes] | None = None, environment_id: str | None = None, agent_identifier: str | None = None, @@ -87,7 +87,7 @@ def __init__( id: id of the CopilotAgent name: Name of the CopilotAgent description: Description of the CopilotAgent - context_provider: Context Provider, to be used by the copilot agent. + context_providers: Context Providers, to be used by the copilot agent. middleware: Agent middleware used by the agent, should be a list of AgentMiddlewareTypes. environment_id: Environment ID of the Power Platform environment containing the Copilot Studio app. Can also be set via COPILOTSTUDIOAGENT__ENVIRONMENTID @@ -118,7 +118,7 @@ def __init__( id=id, name=name, description=description, - context_provider=context_provider, + context_providers=context_providers, middleware=middleware, ) if not client: @@ -190,7 +190,7 @@ def run( messages: str | Message | list[str] | list[Message] | None = None, *, stream: Literal[False] = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse]: ... 
@@ -200,7 +200,7 @@ def run( messages: str | Message | list[str] | list[Message] | None = None, *, stream: Literal[True], - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ... @@ -209,7 +209,7 @@ def run( messages: str | Message | list[str] | list[Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: """Get a response from the agent. @@ -223,7 +223,7 @@ def run( Keyword Args: stream: Whether to stream the response. Defaults to False. - thread: The conversation thread associated with the message(s). + session: The conversation session associated with the message(s). kwargs: Additional keyword arguments. Returns: @@ -231,26 +231,26 @@ def run( When stream=True: A ResponseStream of AgentResponseUpdate items. """ if stream: - return self._run_stream_impl(messages=messages, thread=thread, **kwargs) - return self._run_impl(messages=messages, thread=thread, **kwargs) + return self._run_stream_impl(messages=messages, session=session, **kwargs) + return self._run_impl(messages=messages, session=session, **kwargs) async def _run_impl( self, messages: str | Message | list[str] | list[Message] | None = None, *, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> AgentResponse: """Non-streaming implementation of run.""" - if not thread: - thread = self.get_new_thread() - thread.service_thread_id = await self._start_new_conversation() + if not session: + session = self.create_session() + session.service_session_id = await self._start_new_conversation() input_messages = normalize_messages(messages) question = "\n".join([message.text for message in input_messages]) - activities = self.client.ask_question(question, thread.service_thread_id) + activities = 
self.client.ask_question(question, session.service_session_id) response_messages: list[Message] = [] response_id: str | None = None @@ -263,22 +263,22 @@ def _run_stream_impl( self, messages: str | Message | list[str] | list[Message] | None = None, *, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: """Streaming implementation of run.""" async def _stream() -> AsyncIterable[AgentResponseUpdate]: - nonlocal thread - if not thread: - thread = self.get_new_thread() - thread.service_thread_id = await self._start_new_conversation() + nonlocal session + if not session: + session = self.create_session() + session.service_session_id = await self._start_new_conversation() input_messages = normalize_messages(messages) question = "\n".join([message.text for message in input_messages]) - activities = self.client.ask_question(question, thread.service_thread_id) + activities = self.client.ask_question(question, session.service_session_id) async for message in self._process_activities(activities, streaming=True): yield AgentResponseUpdate( diff --git a/python/packages/copilotstudio/tests/test_copilot_agent.py b/python/packages/copilotstudio/tests/test_copilot_agent.py index fb16f151f3..2bc97fe650 100644 --- a/python/packages/copilotstudio/tests/test_copilot_agent.py +++ b/python/packages/copilotstudio/tests/test_copilot_agent.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch import pytest -from agent_framework import AgentResponse, AgentResponseUpdate, AgentThread, Content, Message +from agent_framework import AgentResponse, AgentResponseUpdate, AgentSession, Content, Message from agent_framework.exceptions import ServiceException, ServiceInitializationError from microsoft_agents.copilotstudio.client import CopilotClient @@ -165,10 +165,10 @@ async def test_run_with_chat_message(self, mock_copilot_client: MagicMock, mock_ assert content.text == "Test response" assert 
response.messages[0].role == "assistant" - async def test_run_with_thread(self, mock_copilot_client: MagicMock, mock_activity: MagicMock) -> None: - """Test run method with existing thread.""" + async def test_run_with_session(self, mock_copilot_client: MagicMock, mock_activity: MagicMock) -> None: + """Test run method with existing session.""" agent = CopilotStudioAgent(client=mock_copilot_client) - thread = AgentThread() + session = AgentSession() conversation_activity = MagicMock() conversation_activity.conversation.id = "test-conversation-id" @@ -176,11 +176,11 @@ async def test_run_with_thread(self, mock_copilot_client: MagicMock, mock_activi mock_copilot_client.start_conversation.return_value = create_async_generator([conversation_activity]) mock_copilot_client.ask_question.return_value = create_async_generator([mock_activity]) - response = await agent.run("test message", thread=thread) + response = await agent.run("test message", session=session) assert isinstance(response, AgentResponse) assert len(response.messages) == 1 - assert thread.service_thread_id == "test-conversation-id" + assert session.service_session_id == "test-conversation-id" async def test_run_start_conversation_failure(self, mock_copilot_client: MagicMock) -> None: """Test run method when conversation start fails.""" @@ -217,10 +217,10 @@ async def test_run_streaming_with_string_message(self, mock_copilot_client: Magi assert response_count == 1 - async def test_run_streaming_with_thread(self, mock_copilot_client: MagicMock) -> None: - """Test run(stream=True) method with existing thread.""" + async def test_run_streaming_with_session(self, mock_copilot_client: MagicMock) -> None: + """Test run(stream=True) method with existing session.""" agent = CopilotStudioAgent(client=mock_copilot_client) - thread = AgentThread() + session = AgentSession() conversation_activity = MagicMock() conversation_activity.conversation.id = "test-conversation-id" @@ -235,7 +235,7 @@ async def 
test_run_streaming_with_thread(self, mock_copilot_client: MagicMock) - mock_copilot_client.ask_question.return_value = create_async_generator([typing_activity]) response_count = 0 - async for response in agent.run("test message", thread=thread, stream=True): + async for response in agent.run("test message", session=session, stream=True): assert isinstance(response, AgentResponseUpdate) content = response.contents[0] assert content.type == "text" @@ -243,7 +243,7 @@ async def test_run_streaming_with_thread(self, mock_copilot_client: MagicMock) - response_count += 1 assert response_count == 1 - assert thread.service_thread_id == "test-conversation-id" + assert session.service_session_id == "test-conversation-id" async def test_run_streaming_no_typing_activity(self, mock_copilot_client: MagicMock) -> None: """Test run(stream=True) method with non-typing activity.""" diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index 1e408169d1..041aa17306 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -15,6 +15,7 @@ from ._mcp import * # noqa: F403 from ._memory import * # noqa: F403 from ._middleware import * # noqa: F403 +from ._sessions import * # noqa: F403 from ._telemetry import * # noqa: F403 from ._threads import * # noqa: F403 from ._tools import * # noqa: F403 diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index f43abb9fa7..10c5bc4b58 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -31,10 +31,9 @@ from ._clients import BaseChatClient, SupportsChatGetResponse from ._logging import get_logger from ._mcp import LOG_LEVEL_MAPPING, MCPTool -from ._memory import Context, ContextProvider from ._middleware import AgentMiddlewareLayer, MiddlewareTypes from ._serialization import SerializationMixin -from ._threads 
import AgentThread, ChatMessageStoreProtocol +from ._sessions import AgentSession, BaseContextProvider, BaseHistoryProvider, InMemoryHistoryProvider, SessionContext from ._tools import ( FunctionInvocationLayer, FunctionTool, @@ -49,7 +48,7 @@ map_chat_to_agent_update, normalize_messages, ) -from .exceptions import AgentExecutionException, AgentInitializationError +from .exceptions import AgentExecutionException from .observability import AgentTelemetryLayer if sys.version_info >= (3, 13): @@ -57,9 +56,9 @@ else: from typing_extensions import TypeVar # type: ignore # pragma: no cover if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover + pass # type: ignore # pragma: no cover else: - from typing_extensions import override # type: ignore[import] # pragma: no cover + pass # type: ignore[import] # pragma: no cover if sys.version_info >= (3, 11): from typing import Self, TypedDict # pragma: no cover else: @@ -75,7 +74,7 @@ logger = get_logger("agent_framework") -ThreadTypeT = TypeVar("ThreadTypeT", bound="AgentThread") +ThreadTypeT = TypeVar("ThreadTypeT", bound="AgentSession") OptionsCoT = TypeVar( "OptionsCoT", bound=TypedDict, # type: ignore[valid-type] @@ -155,7 +154,8 @@ def _sanitize_agent_name(agent_name: str | None) -> str | None: class _RunContext(TypedDict): - thread: AgentThread + session: AgentSession | None + session_context: SessionContext input_messages: list[Message] thread_messages: list[Message] agent_name: str @@ -197,7 +197,7 @@ def __init__(self): self.name = "Custom Agent" self.description = "A fully custom agent implementation" - async def run(self, messages=None, *, stream=False, thread=None, **kwargs): + async def run(self, messages=None, *, stream=False, session=None, **kwargs): if stream: # Your custom streaming implementation async def _stream(): @@ -212,9 +212,15 @@ async def _stream(): return AgentResponse(messages=[], response_id="custom-response") - def get_new_thread(self, **kwargs): - # Return 
your own thread implementation - return {"id": "custom-thread", "messages": []} + def create_session(self, **kwargs): + from agent_framework import AgentSession + + return AgentSession(**kwargs) + + def get_session(self, service_session_id, **kwargs): + from agent_framework import AgentSession + + return AgentSession(service_session_id=service_session_id, **kwargs) # Verify the instance satisfies the protocol @@ -232,7 +238,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: Literal[False] = ..., - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]]: """Get a response from the agent (non-streaming).""" @@ -244,7 +250,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: Literal[True], - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: """Get a streaming response from the agent.""" @@ -255,7 +261,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: """Get a response from the agent. @@ -269,7 +275,7 @@ def run( Keyword Args: stream: Whether to stream the response. Defaults to False. - thread: The conversation thread associated with the message(s). + session: The conversation session associated with the message(s). kwargs: Additional keyword arguments. Returns: @@ -279,8 +285,12 @@ def run( """ ... - def get_new_thread(self, **kwargs: Any) -> AgentThread: - """Creates a new conversation thread for the agent.""" + def create_session(self, **kwargs: Any) -> AgentSession: + """Creates a new conversation session.""" + ... 
+ + def get_session(self, service_session_id: str, **kwargs: Any) -> AgentSession: + """Gets or creates a session for a service-managed session ID.""" ... @@ -294,7 +304,7 @@ class BaseAgent(SerializationMixin): For most use cases, prefer :class:`Agent` which includes all standard layers. This class provides core functionality for agent implementations, including - context providers, middleware support, and thread management. + context providers, middleware support, and session management. Note: BaseAgent cannot be instantiated directly as it doesn't implement the @@ -304,12 +314,12 @@ class BaseAgent(SerializationMixin): Examples: .. code-block:: python - from agent_framework import BaseAgent, AgentThread, AgentResponse + from agent_framework import BaseAgent, AgentSession, AgentResponse # Create a concrete subclass that implements the protocol class SimpleAgent(BaseAgent): - async def run(self, messages=None, *, stream=False, thread=None, **kwargs): + async def run(self, messages=None, *, stream=False, session=None, **kwargs): if stream: async def _stream(): @@ -345,7 +355,7 @@ def __init__( id: str | None = None, name: str | None = None, description: str | None = None, - context_provider: ContextProvider | None = None, + context_providers: Sequence[BaseContextProvider] | None = None, middleware: Sequence[MiddlewareTypes] | None = None, additional_properties: MutableMapping[str, Any] | None = None, **kwargs: Any, @@ -357,7 +367,7 @@ def __init__( a new UUID will be generated. name: The name of the agent, can be None. description: The description of the agent. - context_provider: The context provider to include during agent invocation. + context_providers: Context providers to include during agent invocation. middleware: List of middleware. additional_properties: Additional properties set on the agent. kwargs: Additional keyword arguments (merged into additional_properties). 
@@ -367,7 +377,7 @@ def __init__( self.id = id self.name = name self.description = description - self.context_provider = context_provider + self.context_providers: list[BaseContextProvider] = list(context_providers or []) self.middleware: list[MiddlewareTypes] | None = ( cast(list[MiddlewareTypes], middleware) if middleware is not None else None ) @@ -376,56 +386,32 @@ def __init__( self.additional_properties: dict[str, Any] = cast(dict[str, Any], additional_properties or {}) self.additional_properties.update(kwargs) - async def _notify_thread_of_new_messages( - self, - thread: AgentThread, - input_messages: Message | Sequence[Message], - response_messages: Message | Sequence[Message], - **kwargs: Any, - ) -> None: - """Notify the thread of new messages. - - This also calls the invoked method of a potential context provider on the thread. - - Args: - thread: The thread to notify of new messages. - input_messages: The input messages to notify about. - response_messages: The response messages to notify about. - **kwargs: Any extra arguments to pass from the agent run. - """ - if isinstance(input_messages, Message) or len(input_messages) > 0: - await thread.on_new_messages(input_messages) - if isinstance(response_messages, Message) or len(response_messages) > 0: - await thread.on_new_messages(response_messages) - if thread.context_provider: - await thread.context_provider.invoked(input_messages, response_messages, **kwargs) - - def get_new_thread(self, **kwargs: Any) -> AgentThread: - """Return a new AgentThread instance that is compatible with the agent. + def create_session(self, *, session_id: str | None = None, **kwargs: Any) -> AgentSession: + """Create a new lightweight session. Keyword Args: - kwargs: Additional keyword arguments passed to AgentThread. + session_id: Optional session ID (generated if not provided). + kwargs: Additional keyword arguments. Returns: - A new AgentThread instance configured with the agent's context provider. 
+ A new AgentSession instance. """ - return AgentThread(**kwargs, context_provider=self.context_provider) + return AgentSession(session_id=session_id) - async def deserialize_thread(self, serialized_thread: Any, **kwargs: Any) -> AgentThread: - """Deserialize a thread from its serialized state. + def get_session(self, service_session_id: str, *, session_id: str | None = None, **kwargs: Any) -> AgentSession: + """Get or create a session for a service-managed session ID. Args: - serialized_thread: The serialized thread data. + service_session_id: The service-managed session ID. Keyword Args: + session_id: Optional local session ID (generated if not provided). kwargs: Additional keyword arguments. Returns: - A new AgentThread instance restored from the serialized state. + A new AgentSession instance with service_session_id set. """ - thread: AgentThread = self.get_new_thread() - await thread.update_from_thread_state(serialized_thread, **kwargs) - return thread + return AgentSession(session_id=session_id, service_session_id=service_session_id) def as_tool( self, @@ -621,8 +607,7 @@ def __init__( | Sequence[FunctionTool | Callable[..., Any] | MutableMapping[str, Any] | Any] | None = None, default_options: OptionsCoT | None = None, - chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None, - context_provider: ContextProvider | None = None, + context_providers: Sequence[BaseContextProvider] | None = None, **kwargs: Any, ) -> None: """Initialize a Agent instance. @@ -636,9 +621,7 @@ def __init__( id: The unique identifier for the agent. Will be created automatically if not provided. name: The name of the agent. description: A brief description of the agent's purpose. - chat_message_store_factory: Factory function to create an instance of ChatMessageStoreProtocol. - If not provided, the default in-memory store will be used. - context_provider: The context providers to include during agent invocation. 
+ context_providers: Context providers to include during agent invocation. middleware: List of middleware to intercept agent and function invocations. default_options: A TypedDict containing chat options. When using a typed agent like ``Agent[OpenAIChatOptions]``, this enables IDE autocomplete for @@ -649,19 +632,8 @@ def __init__( These can be overridden at runtime via the ``options`` parameter of ``run()``. tools: The tools to use for the request. kwargs: Any additional keyword arguments. Will be stored as ``additional_properties``. - - Raises: - AgentInitializationError: If both conversation_id and chat_message_store_factory are provided. """ - # Extract conversation_id from options for validation opts = dict(default_options) if default_options else {} - conversation_id = opts.get("conversation_id") - - if conversation_id is not None and chat_message_store_factory is not None: - raise AgentInitializationError( - "Cannot specify both conversation_id and chat_message_store_factory. " - "Use conversation_id for service-managed threads or chat_message_store_factory for local storage." 
- ) if not isinstance(client, FunctionInvocationLayer) and isinstance(client, BaseChatClient): logger.warning( @@ -672,11 +644,10 @@ def __init__( id=id, name=name, description=description, - context_provider=context_provider, + context_providers=context_providers, **kwargs, ) self.client = client - self.chat_message_store_factory = chat_message_store_factory # Get tools from options or named parameter (named param takes precedence) tools_ = tools if tools is not None else opts.pop("tools", None) @@ -779,7 +750,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: Literal[False] = ..., - thread: AgentThread | None = None, + session: AgentSession | None = None, tools: FunctionTool | Callable[..., Any] | MutableMapping[str, Any] @@ -796,7 +767,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: Literal[False] = ..., - thread: AgentThread | None = None, + session: AgentSession | None = None, tools: FunctionTool | Callable[..., Any] | MutableMapping[str, Any] @@ -813,7 +784,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: Literal[True], - thread: AgentThread | None = None, + session: AgentSession | None = None, tools: FunctionTool | Callable[..., Any] | MutableMapping[str, Any] @@ -829,7 +800,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, tools: FunctionTool | Callable[..., Any] | MutableMapping[str, Any] @@ -852,7 +823,9 @@ def run( stream: Whether to stream the response. Defaults to False. Keyword Args: - thread: The thread to use for the agent. + session: The session to use for the agent. + If None, and no settings for the chat client that indicate otherwise, + the run will be stateless. tools: The tools to use for this specific run (merged with default tools). options: A TypedDict containing chat options. 
When using a typed agent like ``Agent[OpenAIChatOptions]``, this enables IDE autocomplete for @@ -871,7 +844,7 @@ def run( async def _run_non_streaming() -> AgentResponse[Any]: ctx = await self._prepare_run_context( messages=messages, - thread=thread, + session=session, tools=tools, options=options, kwargs=kwargs, @@ -886,12 +859,11 @@ async def _run_non_streaming() -> AgentResponse[Any]: if not response: raise AgentExecutionException("Chat client did not return a response.") - await self._finalize_response_and_update_thread( + await self._finalize_response( response=response, agent_name=ctx["agent_name"], - thread=ctx["thread"], - input_messages=ctx["input_messages"], - kwargs=ctx["finalize_kwargs"], + session=ctx["session"], + session_context=ctx["session_context"], ) response_format = ctx["chat_options"].get("response_format") if not ( @@ -923,26 +895,23 @@ async def _post_hook(response: AgentResponse) -> None: if ctx is None: return # No context available (shouldn't happen in normal flow) - # Update thread with conversation_id - await self._update_thread_with_type_and_conversation_id(ctx["thread"], response.response_id) - # Ensure author names are set for all messages for message in response.messages: if message.author_name is None: message.author_name = ctx["agent_name"] - # Notify thread of new messages - await self._notify_thread_of_new_messages( - ctx["thread"], - ctx["input_messages"], - response.messages, - **{k: v for k, v in ctx["finalize_kwargs"].items() if k != "thread"}, + # Run after_run providers (reverse order) + session_context = ctx["session_context"] + session_context._response = AgentResponse( # type: ignore[assignment] + messages=response.messages, + response_id=response.response_id, ) + await self._run_after_providers(session=ctx["session"], context=session_context) async def _get_stream() -> ResponseStream[ChatResponseUpdate, ChatResponse]: ctx_holder["ctx"] = await self._prepare_run_context( messages=messages, - thread=thread, + 
session=session, tools=tools, options=options, kwargs=kwargs, @@ -984,7 +953,7 @@ async def _prepare_run_context( self, *, messages: str | Message | Sequence[str | Message] | None, - thread: AgentThread | None, + session: AgentSession | None, tools: FunctionTool | Callable[..., Any] | MutableMapping[str, Any] @@ -1000,8 +969,22 @@ async def _prepare_run_context( tools_ = tools if tools is not None else opts.pop("tools", None) input_messages = normalize_messages(messages) - thread, run_chat_options, thread_messages = await self._prepare_thread_and_messages( - thread=thread, input_messages=input_messages, **kwargs + + # Auto-inject InMemoryHistoryProvider when session is provided, no context providers + # registered, and no service-side storage indicators + if ( + session is not None + and not self.context_providers + and not session.service_session_id + and not opts.get("conversation_id") + and not opts.get("store") + ): + self.context_providers.append(InMemoryHistoryProvider("memory")) + + session_context, chat_options = await self._prepare_session_and_messages( + session=session, + input_messages=input_messages, + options=opts, ) # Normalize tools @@ -1028,7 +1011,7 @@ async def _prepare_run_context( # Build options dict from run() options merged with provided options run_opts: dict[str, Any] = { "model_id": opts.pop("model_id", None), - "conversation_id": thread.service_thread_id, + "conversation_id": session.service_session_id if session else opts.pop("conversation_id", None), "allow_multiple_tool_calls": opts.pop("allow_multiple_tool_calls", None), "additional_function_arguments": opts.pop("additional_function_arguments", None), "frequency_penalty": opts.pop("frequency_penalty", None), @@ -1049,16 +1032,20 @@ async def _prepare_run_context( } # Remove None values and merge with chat_options run_opts = {k: v for k, v in run_opts.items() if v is not None} - co = _merge_options(run_chat_options, run_opts) + co = _merge_options(chat_options, run_opts) + + # Build 
thread_messages from session context: context messages + input messages + thread_messages: list[Message] = session_context.get_messages(include_input=True) - # Ensure thread is forwarded in kwargs for tool invocation + # Ensure session is forwarded in kwargs for tool invocation finalize_kwargs = dict(kwargs) - finalize_kwargs["thread"] = thread + finalize_kwargs["session"] = session # Filter chat_options from kwargs to prevent duplicate keyword argument filtered_kwargs = {k: v for k, v in finalize_kwargs.items() if k != "chat_options"} return { - "thread": thread, + "session": session, + "session_context": session_context, "input_messages": input_messages, "thread_messages": thread_messages, "agent_name": agent_name, @@ -1067,85 +1054,124 @@ async def _prepare_run_context( "finalize_kwargs": finalize_kwargs, } - async def _finalize_response_and_update_thread( + async def _finalize_response( self, response: ChatResponse, agent_name: str, - thread: AgentThread, - input_messages: list[Message], - kwargs: dict[str, Any], + session: AgentSession | None, + session_context: SessionContext, ) -> None: - """Finalize response by updating thread and setting author names. + """Finalize response by setting author names and running after_run providers. Args: response: The chat response to finalize. agent_name: The name of the agent to set as author. - thread: The conversation thread. - input_messages: The input messages. - kwargs: Additional keyword arguments. + session: The conversation session. + session_context: The invocation context. """ - await self._update_thread_with_type_and_conversation_id(thread, response.conversation_id) - # Ensure that the author name is set for each message in the response. for message in response.messages: if message.author_name is None: message.author_name = agent_name - # Only notify the thread of new messages if the chatResponse was successful - # to avoid inconsistent messages state in the thread. 
- await self._notify_thread_of_new_messages( - thread, - input_messages, - response.messages, - **{k: v for k, v in kwargs.items() if k != "thread"}, + # Set the response on the context for after_run providers + session_context._response = AgentResponse( # type: ignore[assignment] + messages=response.messages, + response_id=response.response_id, ) - @override - def get_new_thread( + # Run after_run providers (reverse order) + await self._run_after_providers(session=session, context=session_context) + + async def _run_after_providers( self, *, - service_thread_id: str | None = None, - **kwargs: Any, - ) -> AgentThread: - """Get a new conversation thread for the agent. - - If you supply a service_thread_id, the thread will be marked as service managed. + session: AgentSession | None, + context: SessionContext, + ) -> None: + """Run after_run on all context providers in reverse order. - If you don't supply a service_thread_id but have a conversation_id configured on the agent, - that conversation_id will be used to create a service-managed thread. + Keyword Args: + session: The conversation session. + context: The invocation context with response populated. + """ + state = session.state if session else {} + for provider in reversed(self.context_providers): + await provider.after_run( + agent=self, + session=session, # type: ignore[arg-type] + context=context, + state=state, + ) - If you don't supply a service_thread_id but have a chat_message_store_factory configured on the agent, - that factory will be used to create a message store for the thread and the thread will be - managed locally. + async def _prepare_session_and_messages( + self, + *, + session: AgentSession | None, + input_messages: list[Message] | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[SessionContext, dict[str, Any]]: + """Prepare the session context and messages for agent execution. - When neither is present, the thread will be created without a service ID or message store. 
- This will be updated based on usage when you run the agent with this thread. - If you run with ``store=True``, the response will include a thread_id and that will be set. - Otherwise a message store is created from the default factory. + Runs the before_run pipeline on all context providers and assembles + the chat options from default options and provider-contributed context. Keyword Args: - service_thread_id: Optional service managed thread ID. - kwargs: Not used at present. + session: The conversation session (None for stateless invocation). + input_messages: Messages to process. + options: Runtime options dict (already copied, safe to mutate). Returns: - A new AgentThread instance. + A tuple containing: + - The SessionContext with provider context populated + - The merged chat options dict """ - if service_thread_id is not None: - return AgentThread( - service_thread_id=service_thread_id, - context_provider=self.context_provider, - ) - if self.default_options.get("conversation_id") is not None: - return AgentThread( - service_thread_id=self.default_options["conversation_id"], - context_provider=self.context_provider, - ) - if self.chat_message_store_factory is not None: - return AgentThread( - message_store=self.chat_message_store_factory(), - context_provider=self.context_provider, + # Create a shallow copy of options and deep copy non-tool values + if self.default_options: + chat_options: dict[str, Any] = {} + for key, value in self.default_options.items(): + if key == "tools": + chat_options[key] = list(value) if value else [] + else: + chat_options[key] = deepcopy(value) + else: + chat_options = {} + + session_context = SessionContext( + session_id=session.session_id if session else None, + service_session_id=session.service_session_id if session else None, + input_messages=input_messages or [], + options=options or {}, + ) + + # Run before_run providers (forward order, skip BaseHistoryProvider with load_messages=False) + state = session.state if session 
else {} + for provider in self.context_providers: + if isinstance(provider, BaseHistoryProvider) and not provider.load_messages: + continue + await provider.before_run( + agent=self, + session=session, # type: ignore[arg-type] + context=session_context, + state=state, ) - return AgentThread(context_provider=self.context_provider) + + # Merge provider-contributed tools into chat_options + if session_context.tools: + if chat_options.get("tools") is not None: + chat_options["tools"].extend(session_context.tools) + else: + chat_options["tools"] = list(session_context.tools) + + # Merge provider-contributed instructions into chat_options + if session_context.instructions: + combined_instructions = "\n".join(session_context.instructions) + if "instructions" in chat_options: + chat_options["instructions"] = f"{chat_options['instructions']}\n{combined_instructions}" + else: + chat_options["instructions"] = combined_instructions + + return session_context, chat_options def as_mcp_server( self, @@ -1256,115 +1282,6 @@ async def _set_logging_level(level: types.LoggingLevel) -> None: # type: ignore return server - async def _update_thread_with_type_and_conversation_id( - self, thread: AgentThread, response_conversation_id: str | None - ) -> None: - """Update thread with storage type and conversation ID. - - Args: - thread: The thread to update. - response_conversation_id: The conversation ID from the response, if any. - - Raises: - AgentExecutionException: If conversation ID is missing for service-managed thread. - """ - if response_conversation_id is None and thread.service_thread_id is not None: - # We were passed a thread that is service managed, but we got no conversation id back from the chat client, - # meaning the service doesn't support service managed threads, - # so the thread cannot be used with this service. - raise AgentExecutionException( - "Service did not return a valid conversation id when using a service managed thread." 
- ) - - if response_conversation_id is not None: - # If we got a conversation id back from the chat client, it means that the service - # supports server side thread storage so we should update the thread with the new id. - thread.service_thread_id = response_conversation_id - if thread.context_provider: - await thread.context_provider.thread_created(thread.service_thread_id) - elif thread.message_store is None and self.chat_message_store_factory is not None: - # If the service doesn't use service side thread storage (i.e. we got no id back from invocation), and - # the thread has no message_store yet, and we have a custom messages store, we should update the thread - # with the custom message_store so that it has somewhere to store the chat history. - thread.message_store = self.chat_message_store_factory() - - async def _prepare_thread_and_messages( - self, - *, - thread: AgentThread | None, - input_messages: list[Message] | None = None, - **kwargs: Any, - ) -> tuple[AgentThread, dict[str, Any], list[Message]]: - """Prepare the thread and messages for agent execution. - - This method prepares the conversation thread, merges context provider data, - and assembles the final message list for the chat client. - - Keyword Args: - thread: The conversation thread. - input_messages: Messages to process. - **kwargs: Any extra arguments to pass from the agent run. - - Returns: - A tuple containing: - - The validated or created thread - - The merged chat options - - The complete list of messages for the chat client - - Raises: - AgentExecutionException: If the conversation IDs on the thread and agent don't match. 
- """ - # Create a shallow copy of options and deep copy non-tool values - # Tools containing HTTP clients or other non-copyable objects cannot be deep copied - if self.default_options: - chat_options: dict[str, Any] = {} - for key, value in self.default_options.items(): - if key == "tools": - # Keep tool references as-is (don't deep copy) - chat_options[key] = list(value) if value else [] - else: - # Deep copy other options to prevent mutation - chat_options[key] = deepcopy(value) - else: - chat_options = {} - thread = thread or self.get_new_thread() - if thread.service_thread_id and thread.context_provider: - await thread.context_provider.thread_created(thread.service_thread_id) - thread_messages: list[Message] = [] - if thread.message_store: - thread_messages.extend(await thread.message_store.list_messages() or []) - context: Context | None = None - if self.context_provider: - # Note: We don't use 'async with' here because the context provider's lifecycle - # should be managed by the user (via async with) or persist across multiple invocations. - # Using async with here would close resources (like retrieval clients) after each query. 
- context = await self.context_provider.invoking(input_messages or [], **kwargs) - if context: - if context.messages: - thread_messages.extend(context.messages) - if context.tools: - if chat_options.get("tools") is not None: - chat_options["tools"].extend(context.tools) - else: - chat_options["tools"] = list(context.tools) - if context.instructions: - chat_options["instructions"] = ( - context.instructions - if "instructions" not in chat_options - else f"{chat_options['instructions']}\n{context.instructions}" - ) - thread_messages.extend(input_messages or []) - if ( - thread.service_thread_id - and chat_options.get("conversation_id") - and thread.service_thread_id != chat_options["conversation_id"] - ): - raise AgentExecutionException( - "The conversation_id set on the agent is different from the one set on the thread, " - "only one ID can be used for a run." - ) - return thread, chat_options, thread_messages - def _get_agent_name(self) -> str: """Get the agent name for message attribution. 
@@ -1404,8 +1321,7 @@ def __init__( | Sequence[FunctionTool | Callable[..., Any] | MutableMapping[str, Any] | Any] | None = None, default_options: OptionsCoT | None = None, - chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None, - context_provider: ContextProvider | None = None, + context_providers: Sequence[BaseContextProvider] | None = None, middleware: Sequence[MiddlewareTypes] | None = None, **kwargs: Any, ) -> None: @@ -1418,8 +1334,7 @@ def __init__( description=description, tools=tools, default_options=default_options, - chat_message_store_factory=chat_message_store_factory, - context_provider=context_provider, + context_providers=context_providers, middleware=middleware, **kwargs, ) diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index 0c3523698e..f5abb1d999 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -28,9 +28,7 @@ from pydantic import BaseModel from ._logging import get_logger -from ._memory import ContextProvider from ._serialization import SerializationMixin -from ._threads import ChatMessageStoreProtocol from ._tools import ( FunctionInvocationConfiguration, FunctionTool, @@ -448,8 +446,7 @@ def as_agent( | Sequence[FunctionTool | Callable[..., Any] | MutableMapping[str, Any]] | None = None, default_options: OptionsCoT | Mapping[str, Any] | None = None, - chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None, - context_provider: ContextProvider | None = None, + context_providers: Sequence[Any] | None = None, middleware: Sequence[MiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, **kwargs: Any, @@ -471,9 +468,7 @@ def as_agent( including temperature, max_tokens, model_id, tool_choice, and more. 
Note: response_format typing does not flow into run outputs when set via default_options, and dict literals are accepted without specialized option typing. - chat_message_store_factory: Factory function to create an instance of ChatMessageStoreProtocol. - If not provided, the default in-memory store will be used. - context_provider: Context providers to include during agent invocation. + context_providers: Context providers to include during agent invocation. middleware: List of middleware to intercept agent and function invocations. function_invocation_configuration: Optional function invocation configuration override. kwargs: Any additional keyword arguments. Will be stored as ``additional_properties``. @@ -509,8 +504,7 @@ def as_agent( instructions=instructions, tools=tools, default_options=cast(Any, default_options), - chat_message_store_factory=chat_message_store_factory, - context_provider=context_provider, + context_providers=context_providers, middleware=middleware, function_invocation_configuration=function_invocation_configuration, **kwargs, diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index d096e96c2a..969ef1efc9 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -36,7 +36,7 @@ from ._agents import SupportsAgentRun from ._clients import SupportsChatGetResponse - from ._threads import AgentThread + from ._sessions import AgentSession from ._tools import FunctionTool from ._types import ChatOptions, ChatResponse, ChatResponseUpdate @@ -118,7 +118,7 @@ class AgentContext: Attributes: agent: The agent being invoked. messages: The messages being sent to the agent. - thread: The agent thread for this invocation, if any. + session: The agent session for this invocation, if any. options: The options for the agent invocation as a dict. stream: Whether this is a streaming invocation. 
metadata: Metadata dictionary for sharing data between agent middleware. @@ -138,7 +138,7 @@ class LoggingMiddleware(AgentMiddleware): async def process(self, context: AgentContext, call_next): print(f"Agent: {context.agent.name}") print(f"Messages: {len(context.messages)}") - print(f"Thread: {context.thread}") + print(f"Session: {context.session}") print(f"Streaming: {context.stream}") # Store metadata @@ -156,7 +156,7 @@ def __init__( *, agent: SupportsAgentRun, messages: list[Message], - thread: AgentThread | None = None, + session: AgentSession | None = None, options: Mapping[str, Any] | None = None, stream: bool = False, metadata: Mapping[str, Any] | None = None, @@ -175,7 +175,7 @@ def __init__( Args: agent: The agent being invoked. messages: The messages being sent to the agent. - thread: The agent thread for this invocation, if any. + session: The agent session for this invocation, if any. options: The options for the agent invocation as a dict. stream: Whether this is a streaming invocation. metadata: Metadata dictionary for sharing data between agent middleware. 
@@ -187,7 +187,7 @@ def __init__( """ self.agent = agent self.messages = messages - self.thread = thread + self.session = session self.options = options self.stream = stream self.metadata = metadata if metadata is not None else {} @@ -1098,7 +1098,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: Literal[False] = ..., - thread: AgentThread | None = None, + session: AgentSession | None = None, middleware: Sequence[MiddlewareTypes] | None = None, options: ChatOptions[ResponseModelBoundT], **kwargs: Any, @@ -1110,7 +1110,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: Literal[False] = ..., - thread: AgentThread | None = None, + session: AgentSession | None = None, middleware: Sequence[MiddlewareTypes] | None = None, options: ChatOptions[None] | None = None, **kwargs: Any, @@ -1122,7 +1122,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: Literal[True], - thread: AgentThread | None = None, + session: AgentSession | None = None, middleware: Sequence[MiddlewareTypes] | None = None, options: ChatOptions[Any] | None = None, **kwargs: Any, @@ -1133,7 +1133,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, middleware: Sequence[MiddlewareTypes] | None = None, options: ChatOptions[Any] | None = None, **kwargs: Any, @@ -1157,12 +1157,12 @@ def run( # Execute with middleware if available if not pipeline.has_middlewares: - return super().run(messages, stream=stream, thread=thread, options=options, **combined_kwargs) # type: ignore[misc, no-any-return] + return super().run(messages, stream=stream, session=session, options=options, **combined_kwargs) # type: ignore[misc, no-any-return] context = AgentContext( agent=self, # type: ignore[arg-type] messages=prepare_messages(messages), # type: ignore[arg-type] - thread=thread, + 
session=session, options=options, stream=stream, kwargs=combined_kwargs, @@ -1197,7 +1197,7 @@ def _middleware_handler( return super().run( # type: ignore[misc, no-any-return] context.messages, stream=context.stream, - thread=context.thread, + session=context.session, options=context.options, **context.kwargs, ) diff --git a/python/packages/core/agent_framework/_workflows/_agent.py b/python/packages/core/agent_framework/_workflows/_agent.py index 06b3bbb613..961ac42b72 100644 --- a/python/packages/core/agent_framework/_workflows/_agent.py +++ b/python/packages/core/agent_framework/_workflows/_agent.py @@ -14,7 +14,7 @@ from agent_framework import ( AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, BaseAgent, Content, Message, @@ -130,7 +130,7 @@ def run( messages: str | Message | list[str] | list[Message] | None = None, *, stream: Literal[True], - thread: AgentThread | None = None, + session: AgentSession | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, **kwargs: Any, @@ -142,7 +142,7 @@ async def run( messages: str | Message | list[str] | list[Message] | None = None, *, stream: Literal[False] = ..., - thread: AgentThread | None = None, + session: AgentSession | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, **kwargs: Any, @@ -153,7 +153,7 @@ def run( messages: str | Message | list[str] | list[Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, **kwargs: Any, @@ -167,7 +167,7 @@ def run( Keyword Args: stream: If True, returns an async iterable of updates. If False (default), returns an awaitable AgentResponse. - thread: The conversation thread. If None, a new thread will be created. + session: The agent session for conversation context. checkpoint_id: ID of checkpoint to restore from. 
If provided, the workflow resumes from this checkpoint instead of starting fresh. checkpoint_storage: Runtime checkpoint storage. When provided with checkpoint_id, @@ -187,14 +187,14 @@ def run( if stream: return self._run_streaming( messages=messages, - thread=thread, + session=session, checkpoint_id=checkpoint_id, checkpoint_storage=checkpoint_storage, **kwargs, ) return self._run_non_streaming( messages=messages, - thread=thread, + session=session, checkpoint_id=checkpoint_id, checkpoint_storage=checkpoint_storage, **kwargs, @@ -204,30 +204,26 @@ async def _run_non_streaming( self, messages: str | Message | list[str] | list[Message] | None = None, *, - thread: AgentThread | None = None, + session: AgentSession | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, **kwargs: Any, ) -> AgentResponse: """Internal non-streaming implementation.""" input_messages = normalize_messages_input(messages) - thread = thread or self.get_new_thread() response_id = str(uuid.uuid4()) response = await self._run_impl( - input_messages, response_id, thread, checkpoint_id, checkpoint_storage, **kwargs + input_messages, response_id, session, checkpoint_id, checkpoint_storage, **kwargs ) - # Notify thread of new messages (both input and response messages) - await self._notify_thread_of_new_messages(thread, input_messages, response.messages) - return response async def _run_streaming( self, messages: str | Message | list[str] | list[Message] | None = None, *, - thread: AgentThread | None = None, + session: AgentSession | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, **kwargs: Any, @@ -239,27 +235,20 @@ async def _run_streaming( to function call and approval request contents. 
""" input_messages = normalize_messages_input(messages) - thread = thread or self.get_new_thread() response_updates: list[AgentResponseUpdate] = [] response_id = str(uuid.uuid4()) async for update in self._run_stream_impl( - input_messages, response_id, thread, checkpoint_id, checkpoint_storage, **kwargs + input_messages, response_id, session, checkpoint_id, checkpoint_storage, **kwargs ): response_updates.append(update) yield update - # Convert updates to final response. - response = self.merge_updates(response_updates, response_id) - - # Notify thread of new messages (both input and response messages) - await self._notify_thread_of_new_messages(thread, input_messages, response.messages) - async def _run_impl( self, input_messages: list[Message], response_id: str, - thread: AgentThread, + session: AgentSession | None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, **kwargs: Any, @@ -269,7 +258,7 @@ async def _run_impl( Args: input_messages: Normalized input messages to process. response_id: The unique response ID for this workflow execution. - thread: The conversation thread containing message history. + session: The agent session for conversation context. checkpoint_id: ID of checkpoint to restore from. checkpoint_storage: Runtime checkpoint storage. 
**kwargs: Additional keyword arguments passed through to the underlying @@ -280,7 +269,7 @@ async def _run_impl( """ output_events: list[WorkflowEvent[Any]] = [] async for event in self._run_core( - input_messages, thread, checkpoint_id, checkpoint_storage, streaming=False, **kwargs + input_messages, checkpoint_id, checkpoint_storage, streaming=False, **kwargs ): if event.type == "output" or event.type == "request_info": output_events.append(event) @@ -291,7 +280,7 @@ async def _run_stream_impl( self, input_messages: list[Message], response_id: str, - thread: AgentThread, + session: AgentSession | None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, **kwargs: Any, @@ -301,7 +290,7 @@ async def _run_stream_impl( Args: input_messages: Normalized input messages to process. response_id: The unique response ID for this workflow execution. - thread: The conversation thread containing message history. + session: The agent session for conversation context. checkpoint_id: ID of checkpoint to restore from. checkpoint_storage: Runtime checkpoint storage. **kwargs: Additional keyword arguments passed through to the underlying @@ -311,7 +300,7 @@ async def _run_stream_impl( AgentResponseUpdate objects representing the workflow execution progress. """ async for event in self._run_core( - input_messages, thread, checkpoint_id, checkpoint_storage, streaming=True, **kwargs + input_messages, checkpoint_id, checkpoint_storage, streaming=True, **kwargs ): updates = self._convert_workflow_event_to_agent_response_updates(response_id, event) for update in updates: @@ -320,7 +309,6 @@ async def _run_stream_impl( async def _run_core( self, input_messages: list[Message], - thread: AgentThread, checkpoint_id: str | None, checkpoint_storage: CheckpointStorage | None, streaming: bool, @@ -330,7 +318,6 @@ async def _run_core( Args: input_messages: Normalized input messages to process. - thread: The conversation thread containing message history. 
checkpoint_id: ID of checkpoint to restore from. checkpoint_storage: Runtime checkpoint storage. streaming: Whether to use streaming workflow methods. @@ -371,10 +358,9 @@ async def _run_core( yield event else: - conversation_messages = await self._build_conversation_messages(thread, input_messages) if streaming: async for event in self.workflow.run( - message=conversation_messages, + message=input_messages, stream=True, checkpoint_storage=checkpoint_storage, **kwargs, @@ -382,7 +368,7 @@ async def _run_core( yield event else: for event in await self.workflow.run( - message=conversation_messages, + message=input_messages, checkpoint_storage=checkpoint_storage, **kwargs, ): @@ -390,28 +376,6 @@ async def _run_core( # endregion Run Methods - async def _build_conversation_messages( - self, - thread: AgentThread, - input_messages: list[Message], - ) -> list[Message]: - """Build the complete conversation by prepending thread history to input messages. - - Args: - thread: The conversation thread containing message history. - input_messages: The new input messages to append. - - Returns: - A list of Message objects representing the full conversation. - """ - conversation_messages: list[Message] = [] - if thread.message_store: - history = await thread.message_store.list_messages() - if history: - conversation_messages.extend(history) - conversation_messages.extend(input_messages) - return conversation_messages - def _process_pending_requests(self, input_messages: list[Message]) -> dict[str, Any]: """Process pending requests by extracting function responses and updating state. 
diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 85bc236982..6252bf8ffe 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -11,10 +11,12 @@ from agent_framework import Content from .._agents import SupportsAgentRun -from .._threads import AgentThread +from .._sessions import AgentSession from .._types import AgentResponse, AgentResponseUpdate, Message from ._agent_utils import resolve_agent_id +from ._checkpoint_encoding import encode_checkpoint_value from ._const import WORKFLOW_RUN_KWARGS_KEY +from ._conversation_state import encode_chat_messages from ._executor import Executor, handler from ._message_utils import normalize_messages_input from ._request_info_mixin import response_handler @@ -81,14 +83,14 @@ def __init__( self, agent: SupportsAgentRun, *, - agent_thread: AgentThread | None = None, + session: AgentSession | None = None, id: str | None = None, ): """Initialize the executor with a unique identifier. Args: agent: The agent to be wrapped by this executor. - agent_thread: The thread to use for running the agent. If None, a new thread will be created. + session: The session to use for running the agent. If None, a new session will be created. id: A unique identifier for the executor. If None, the agent's name will be used if available. 
""" # Prefer provided id; else use agent.name if present; else generate deterministic prefix @@ -97,7 +99,7 @@ def __init__( raise ValueError("Agent must have a non-empty name or id or an explicit id must be provided.") super().__init__(exec_id) self._agent = agent - self._agent_thread = agent_thread or self._agent.get_new_thread() + self._session = session or self._agent.create_session() self._pending_agent_requests: dict[str, Content] = {} self._pending_responses_to_agent: list[Content] = [] @@ -205,37 +207,35 @@ async def handle_user_input_response( async def on_checkpoint_save(self) -> dict[str, Any]: """Capture current executor state for checkpointing. - NOTE: if the thread storage is on the server side, the full thread state - may not be serialized locally. Therefore, we are relying on the server-side - to ensure the thread state is preserved and immutable across checkpoints. - This is not the case for AzureAI Agents, but works for the Responses API. + NOTE: if the session uses service-side storage, the full session state + may not be serialized locally. Returns: - Dict containing serialized cache and thread state + Dict containing serialized cache and session state """ - # Check if using AzureAIAgentClient with server-side thread and warn about checkpointing limitations - if is_chat_agent(self._agent) and self._agent_thread.service_thread_id is not None: + # Check if using AzureAIAgentClient with server-side session and warn about checkpointing limitations + if is_chat_agent(self._agent) and self._session.service_session_id is not None: client_class_name = self._agent.client.__class__.__name__ client_module = self._agent.client.__class__.__module__ if client_class_name == "AzureAIAgentClient" and "azure_ai" in client_module: logger.warning( - "Checkpointing an AgentExecutor with AzureAIAgentClient that uses server-side threads. " - "Currently, checkpointing does not capture messages from server-side threads " - "(service_thread_id: %s). 
The thread state in checkpoints is not immutable and can be " + "Checkpointing an AgentExecutor with AzureAIAgentClient that uses server-side sessions. " + "Currently, checkpointing does not capture messages from server-side sessions " + "(service_session_id: %s). The session state in checkpoints is not immutable and can be " "modified by subsequent runs. If you need reliable checkpointing with Azure AI agents, " - "consider implementing a custom executor and managing the thread state yourself.", - self._agent_thread.service_thread_id, + "consider implementing a custom executor and managing the session state yourself.", + self._session.service_session_id, ) - serialized_thread = await self._agent_thread.serialize() + serialized_session = self._session.to_dict() return { - "cache": self._cache, - "full_conversation": self._full_conversation, - "agent_thread": serialized_thread, - "pending_agent_requests": self._pending_agent_requests, - "pending_responses_to_agent": self._pending_responses_to_agent, + "cache": encode_chat_messages(self._cache), + "full_conversation": encode_chat_messages(self._full_conversation), + "agent_session": serialized_session, + "pending_agent_requests": encode_checkpoint_value(self._pending_agent_requests), + "pending_responses_to_agent": encode_checkpoint_value(self._pending_responses_to_agent), } @override @@ -251,17 +251,15 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: full_conversation_payload = state.get("full_conversation") self._full_conversation = full_conversation_payload or [] - thread_payload = state.get("agent_thread") - if thread_payload: + session_payload = state.get("agent_session") + if session_payload: try: - # Deserialize the thread state directly - self._agent_thread = await AgentThread.deserialize(thread_payload) - + self._session = AgentSession.from_dict(session_payload) except Exception as exc: - logger.warning("Failed to restore agent thread: %s", exc) - self._agent_thread = 
self._agent.get_new_thread() + logger.warning("Failed to restore agent session: %s", exc) + self._session = self._agent.create_session() else: - self._agent_thread = self._agent.get_new_thread() + self._session = self._agent.create_session() pending_requests_payload = state.get("pending_agent_requests") if pending_requests_payload: @@ -321,7 +319,7 @@ async def _run_agent(self, ctx: WorkflowContext[Never, AgentResponse]) -> AgentR response = await self._agent.run( self._cache, stream=False, - thread=self._agent_thread, + session=self._session, options=options, **run_kwargs, ) @@ -352,7 +350,7 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp async for update in self._agent.run( self._cache, stream=True, - thread=self._agent_thread, + session=self._session, options=options, **run_kwargs, ): diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index 64ceefe673..374c0361cf 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -39,7 +39,7 @@ from ._agents import SupportsAgentRun from ._clients import SupportsChatGetResponse - from ._threads import AgentThread + from ._sessions import AgentSession from ._tools import FunctionTool from ._types import ( AgentResponse, @@ -1280,7 +1280,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: Literal[False] = ..., - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]]: ... @@ -1290,7 +1290,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: Literal[True], - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... 
@@ -1299,7 +1299,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: """Trace agent runs with OpenTelemetry spans and metrics.""" @@ -1312,7 +1312,7 @@ def run( return super_run( # type: ignore[no-any-return] messages=messages, stream=stream, - thread=thread, + session=session, **kwargs, ) @@ -1327,7 +1327,7 @@ def run( agent_id=getattr(self, "id", "unknown"), agent_name=getattr(self, "name", None) or getattr(self, "id", "unknown"), agent_description=getattr(self, "description", None), - thread_id=thread.service_thread_id if thread else None, + thread_id=session.service_session_id if session else None, all_options=merged_options, **kwargs, ) @@ -1336,7 +1336,7 @@ def run( run_result = super_run( messages=messages, stream=True, - thread=thread, + session=session, **kwargs, ) if isinstance(run_result, ResponseStream): @@ -1423,7 +1423,7 @@ async def _run() -> AgentResponse: response = await super_run( messages=messages, stream=False, - thread=thread, + session=session, **kwargs, ) except Exception as exception: diff --git a/python/packages/core/agent_framework/openai/_assistant_provider.py b/python/packages/core/agent_framework/openai/_assistant_provider.py index 90820ec5d2..a64ae87b95 100644 --- a/python/packages/core/agent_framework/openai/_assistant_provider.py +++ b/python/packages/core/agent_framework/openai/_assistant_provider.py @@ -13,7 +13,7 @@ from agent_framework._settings import SecretString, load_settings from .._agents import Agent -from .._memory import ContextProvider +from .._sessions import BaseContextProvider from .._middleware import MiddlewareTypes from .._tools import FunctionTool from .._types import normalize_tools @@ -208,7 +208,7 @@ async def create_agent( metadata: dict[str, str] | None = None, 
default_options: OptionsCoT | None = None, middleware: Sequence[MiddlewareTypes] | None = None, - context_provider: ContextProvider | None = None, + context_providers: Sequence[BaseContextProvider] | None = None, ) -> Agent[OptionsCoT]: """Create a new assistant on OpenAI and return a Agent. @@ -230,7 +230,7 @@ async def create_agent( These options are applied to every run unless overridden. Include ``response_format`` here for structured output responses. middleware: MiddlewareTypes for the Agent. - context_provider: Context provider for the Agent. + context_providers: Context providers for the Agent. Returns: A Agent instance wrapping the created assistant. @@ -304,7 +304,7 @@ async def create_agent( tools=normalized_tools, instructions=instructions, middleware=middleware, - context_provider=context_provider, + context_providers=context_providers, default_options=default_options, ) @@ -316,7 +316,7 @@ async def get_agent( instructions: str | None = None, default_options: OptionsCoT | None = None, middleware: Sequence[MiddlewareTypes] | None = None, - context_provider: ContextProvider | None = None, + context_providers: Sequence[BaseContextProvider] | None = None, ) -> Agent[OptionsCoT]: """Retrieve an existing assistant by ID and return a Agent. @@ -335,7 +335,7 @@ async def get_agent( default_options: A TypedDict containing default chat options for the agent. These options are applied to every run unless overridden. middleware: MiddlewareTypes for the Agent. - context_provider: Context provider for the Agent. + context_providers: Context providers for the Agent. Returns: A Agent instance wrapping the retrieved assistant. 
@@ -371,7 +371,7 @@ async def get_agent( instructions=instructions, default_options=default_options, middleware=middleware, - context_provider=context_provider, + context_providers=context_providers, ) def as_agent( @@ -382,7 +382,7 @@ def as_agent( instructions: str | None = None, default_options: OptionsCoT | None = None, middleware: Sequence[MiddlewareTypes] | None = None, - context_provider: ContextProvider | None = None, + context_providers: Sequence[BaseContextProvider] | None = None, ) -> Agent[OptionsCoT]: """Wrap an existing SDK Assistant object as a Agent. @@ -400,7 +400,7 @@ def as_agent( default_options: A TypedDict containing default chat options for the agent. These options are applied to every run unless overridden. middleware: MiddlewareTypes for the Agent. - context_provider: Context provider for the Agent. + context_providers: Context providers for the Agent. Returns: A Agent instance wrapping the assistant. @@ -437,7 +437,7 @@ def as_agent( instructions=instructions, default_options=default_options, middleware=middleware, - context_provider=context_provider, + context_providers=context_providers, ) def _validate_function_tools( @@ -524,7 +524,7 @@ def _create_chat_agent_from_assistant( tools: list[FunctionTool | MutableMapping[str, Any]] | None, instructions: str | None, middleware: Sequence[MiddlewareTypes] | None, - context_provider: ContextProvider | None, + context_providers: Sequence[BaseContextProvider] | None, default_options: OptionsCoT | None = None, **kwargs: Any, ) -> Agent[OptionsCoT]: @@ -535,7 +535,7 @@ def _create_chat_agent_from_assistant( tools: Tools for the agent. instructions: Instructions override. middleware: MiddlewareTypes for the agent. - context_provider: Context provider for the agent. + context_providers: Context providers for the agent. default_options: Default chat options for the agent (may include response_format). **kwargs: Additional arguments passed to Agent. 
@@ -563,7 +563,7 @@ def _create_chat_agent_from_assistant( instructions=final_instructions, tools=tools if tools else None, middleware=middleware, - context_provider=context_provider, + context_providers=context_providers, default_options=default_options, # type: ignore[arg-type] **kwargs, ) diff --git a/python/packages/devui/agent_framework_devui/_conversations.py b/python/packages/devui/agent_framework_devui/_conversations.py index 2ea28f6e6a..51f42a2389 100644 --- a/python/packages/devui/agent_framework_devui/_conversations.py +++ b/python/packages/devui/agent_framework_devui/_conversations.py @@ -13,7 +13,7 @@ from abc import ABC, abstractmethod from typing import Any, Literal, cast -from agent_framework import AgentThread, Message +from agent_framework import AgentSession, AgentThread, Message from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage from openai.types.conversations import Conversation, ConversationDeletedResource from openai.types.conversations.conversation_item import ConversationItem @@ -152,17 +152,17 @@ async def get_item(self, conversation_id: str, item_id: str) -> ConversationItem pass @abstractmethod - def get_thread(self, conversation_id: str) -> AgentThread | None: - """Get underlying AgentThread for execution (internal use). + def get_session(self, conversation_id: str) -> AgentSession | None: + """Get AgentSession for agent execution. This is the critical method that allows the executor to get the - AgentThread for running agents with conversation context. + AgentSession for running agents with conversation context. 
Args: conversation_id: Conversation ID Returns: - AgentThread object or None if not found + AgentSession object or None if not found """ pass @@ -229,8 +229,9 @@ def create_conversation( conv_id = conversation_id or f"conv_{uuid.uuid4().hex}" created_at = int(time.time()) - # Create AgentThread with default ChatMessageStore + # Create AgentThread for internal message storage and AgentSession for execution thread = AgentThread() + session = AgentSession(session_id=conv_id) # Create session-scoped checkpoint storage (one per conversation) checkpoint_storage = InMemoryCheckpointStorage() @@ -238,6 +239,7 @@ def create_conversation( self._conversations[conv_id] = { "id": conv_id, "thread": thread, + "session": session, "checkpoint_storage": checkpoint_storage, # Stored alongside thread "metadata": metadata or {}, "created_at": created_at, @@ -589,10 +591,10 @@ async def get_item(self, conversation_id: str, item_id: str) -> ConversationItem return None - def get_thread(self, conversation_id: str) -> AgentThread | None: - """Get AgentThread for execution - CRITICAL for agent.run().""" + def get_session(self, conversation_id: str) -> AgentSession | None: + """Get AgentSession for execution - CRITICAL for agent.run().""" conv_data = self._conversations.get(conversation_id) - return conv_data["thread"] if conv_data else None + return conv_data["session"] if conv_data else None def add_trace(self, conversation_id: str, trace_event: dict[str, Any]) -> None: """Add a trace event to the conversation for context inspection. 
diff --git a/python/packages/devui/agent_framework_devui/_executor.py b/python/packages/devui/agent_framework_devui/_executor.py index 92e6301b66..e019917630 100644 --- a/python/packages/devui/agent_framework_devui/_executor.py +++ b/python/packages/devui/agent_framework_devui/_executor.py @@ -308,15 +308,15 @@ async def _execute_agent( # Convert input to proper Message or string user_message = self._convert_input_to_chat_message(request.input) - # Get thread from conversation parameter (OpenAI standard!) - thread = None + # Get session from conversation parameter (OpenAI standard!) + session = None conversation_id = request._get_conversation_id() if conversation_id: - thread = self.conversation_store.get_thread(conversation_id) - if thread: + session = self.conversation_store.get_session(conversation_id) + if session: logger.debug(f"Using existing conversation: {conversation_id}") else: - logger.warning(f"Conversation {conversation_id} not found, proceeding without thread") + logger.warning(f"Conversation {conversation_id} not found, proceeding without session") if isinstance(user_message, str): logger.debug(f"Executing agent with text input: {user_message[:100]}...") @@ -331,8 +331,8 @@ async def _execute_agent( # Agent must have run() method - use stream=True for streaming if hasattr(agent, "run") and callable(agent.run): # Use Agent Framework's run() with stream=True for streaming - if thread: - async for update in agent.run(user_message, stream=True, thread=thread): + if session: + async for update in agent.run(user_message, stream=True, session=session): for trace_event in trace_collector.get_pending_events(): yield trace_event diff --git a/python/packages/devui/tests/devui/test_conversations.py b/python/packages/devui/tests/devui/test_conversations.py index ccaea3524c..812e0e718f 100644 --- a/python/packages/devui/tests/devui/test_conversations.py +++ b/python/packages/devui/tests/devui/test_conversations.py @@ -83,29 +83,29 @@ async def 
test_delete_conversation(): @pytest.mark.asyncio -async def test_get_thread(): - """Test getting underlying AgentThread.""" +async def test_get_session(): + """Test getting AgentSession for execution.""" store = InMemoryConversationStore() # Create conversation conversation = store.create_conversation(metadata={"agent_id": "test_agent"}) - # Get thread - thread = store.get_thread(conversation.id) + # Get session + session = store.get_session(conversation.id) - assert thread is not None - # AgentThread should have message_store - assert hasattr(thread, "message_store") + assert session is not None + # AgentSession should have session_id + assert hasattr(session, "session_id") @pytest.mark.asyncio -async def test_get_thread_not_found(): - """Test getting thread for non-existent conversation.""" +async def test_get_session_not_found(): + """Test getting session for non-existent conversation.""" store = InMemoryConversationStore() - thread = store.get_thread("conv_nonexistent") + session = store.get_session("conv_nonexistent") - assert thread is None + assert session is None @pytest.mark.asyncio @@ -206,8 +206,8 @@ async def test_list_items_converts_function_calls(): # Create conversation conversation = store.create_conversation(metadata={"agent_id": "test_agent"}) - # Get the underlying thread and set up message store - thread = store.get_thread(conversation.id) + # Get the underlying thread for internal message store setup + thread = store._conversations[conversation.id]["thread"] assert thread is not None # Initialize message store if not present @@ -291,8 +291,8 @@ async def test_list_items_handles_images_and_files(): # Create conversation conversation = store.create_conversation(metadata={"agent_id": "test_agent"}) - # Get the underlying thread - thread = store.get_thread(conversation.id) + # Get the underlying thread for internal message store setup + thread = store._conversations[conversation.id]["thread"] assert thread is not None if thread.message_store is 
None: diff --git a/python/packages/durabletask/AGENTS.md b/python/packages/durabletask/AGENTS.md index 905f462212..6e185bcd98 100644 --- a/python/packages/durabletask/AGENTS.md +++ b/python/packages/durabletask/AGENTS.md @@ -18,7 +18,7 @@ Durable execution support for long-running agent workflows using Azure Durable F ### State Management - **`DurableAgentState`** - State container for durable agents -- **`DurableAgentThread`** - Thread management for durable agents +- **`DurableAgentSession`** - Session management for durable agents - **`DurableAIAgentOrchestrationContext`** - Orchestration context ### Callbacks diff --git a/python/packages/durabletask/agent_framework_durabletask/__init__.py b/python/packages/durabletask/agent_framework_durabletask/__init__.py index 84a1361d9a..bb0da56af4 100644 --- a/python/packages/durabletask/agent_framework_durabletask/__init__.py +++ b/python/packages/durabletask/agent_framework_durabletask/__init__.py @@ -45,7 +45,7 @@ ) from ._entities import AgentEntity, AgentEntityStateProviderMixin from ._executors import DurableAgentExecutor -from ._models import AgentSessionId, DurableAgentThread, RunRequest +from ._models import AgentSessionId, DurableAgentSession, RunRequest from ._orchestration_context import DurableAIAgentOrchestrationContext from ._response_utils import ensure_response_format, load_agent_response from ._shim import DurableAIAgent @@ -99,7 +99,7 @@ "DurableAgentStateUriContent", "DurableAgentStateUsage", "DurableAgentStateUsageContent", - "DurableAgentThread", + "DurableAgentSession", "DurableStateFields", "RunRequest", "__version__", diff --git a/python/packages/durabletask/agent_framework_durabletask/_executors.py b/python/packages/durabletask/agent_framework_durabletask/_executors.py index 2193f94e16..7adb79875a 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_executors.py +++ b/python/packages/durabletask/agent_framework_durabletask/_executors.py @@ -16,7 +16,7 @@ from datetime import 
datetime, timezone from typing import Any, Generic, TypeVar -from agent_framework import AgentResponse, AgentThread, Content, Message, get_logger +from agent_framework import AgentResponse, AgentSession, Content, Message, get_logger from durabletask.client import TaskHubGrpcClient from durabletask.entities import EntityInstanceId from durabletask.task import CompletableTask, CompositeTask, OrchestrationContext, Task @@ -24,7 +24,7 @@ from ._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS from ._durable_agent_state import DurableAgentState -from ._models import AgentSessionId, DurableAgentThread, RunRequest +from ._models import AgentSessionId, DurableAgentSession, RunRequest from ._response_utils import ensure_response_format, load_agent_response logger = get_logger("agent_framework.durabletask.executors") @@ -114,7 +114,7 @@ def run_durable_agent( self, agent_name: str, run_request: RunRequest, - thread: AgentThread | None = None, + session: AgentSession | None = None, ) -> TaskT: """Execute the durable agent. 
@@ -123,20 +123,20 @@ def run_durable_agent( """ raise NotImplementedError - def get_new_thread(self, agent_name: str, **kwargs: Any) -> DurableAgentThread: - """Create a new DurableAgentThread with random session ID.""" + def get_new_session(self, agent_name: str, **kwargs: Any) -> DurableAgentSession: + """Create a new DurableAgentSession with random session ID.""" session_id = self._create_session_id(agent_name) - return DurableAgentThread.from_session_id(session_id, **kwargs) + return DurableAgentSession.from_session_id(session_id, **kwargs) def _create_session_id( self, agent_name: str, - thread: AgentThread | None = None, + session: AgentSession | None = None, ) -> AgentSessionId: """Create the AgentSessionId for the execution.""" - if isinstance(thread, DurableAgentThread) and thread.session_id is not None: - return thread.session_id - # Create new session ID - either no thread provided or it's a regular AgentThread + if isinstance(session, DurableAgentSession) and session.durable_session_id is not None: + return session.durable_session_id + # Create new session ID - either no session provided or it's a regular AgentSession key = self.generate_unique_id() return AgentSessionId(name=agent_name, key=key) @@ -217,7 +217,7 @@ def run_durable_agent( self, agent_name: str, run_request: RunRequest, - thread: AgentThread | None = None, + session: AgentSession | None = None, ) -> AgentResponse: """Execute the agent via the durabletask client. 
@@ -231,14 +231,14 @@ def run_durable_agent( Args: agent_name: Name of the agent to execute run_request: The run request containing message and optional response format - thread: Optional conversation thread (creates new if not provided) + session: Optional conversation session (creates new if not provided) Returns: AgentResponse: The agent's response after execution completes, or an immediate acknowledgement if wait_for_response is False """ # Signal the entity with the request - entity_id = self._signal_agent_entity(agent_name, run_request, thread) + entity_id = self._signal_agent_entity(agent_name, run_request, session) # If fire-and-forget mode, return immediately without polling if not run_request.wait_for_response: @@ -258,20 +258,20 @@ def _signal_agent_entity( self, agent_name: str, run_request: RunRequest, - thread: AgentThread | None, + session: AgentSession | None, ) -> EntityInstanceId: """Signal the agent entity with a run request. Args: agent_name: Name of the agent to execute run_request: The run request containing message and optional response format - thread: Optional conversation thread + session: Optional conversation session Returns: entity_id """ # Get or create session ID - session_id = self._create_session_id(agent_name, thread) + session_id = self._create_session_id(agent_name, session) # Create the entity ID entity_id = EntityInstanceId( @@ -460,7 +460,7 @@ def run_durable_agent( self, agent_name: str, run_request: RunRequest, - thread: AgentThread | None = None, + session: AgentSession | None = None, ) -> DurableAgentTask: """Execute the agent via orchestration context. 
@@ -470,13 +470,13 @@ def run_durable_agent( Args: agent_name: Name of the agent to execute run_request: The run request containing message and optional response format - thread: Optional conversation thread (creates new if not provided) + session: Optional conversation session (creates new if not provided) Returns: DurableAgentTask: A task wrapping the entity call that yields AgentResponse """ # Resolve session - session_id = self._create_session_id(agent_name, thread) + session_id = self._create_session_id(agent_name, session) # Create the entity ID entity_id = EntityInstanceId( diff --git a/python/packages/durabletask/agent_framework_durabletask/_models.py b/python/packages/durabletask/agent_framework_durabletask/_models.py index 3d20828fc7..1c5484afbf 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_models.py +++ b/python/packages/durabletask/agent_framework_durabletask/_models.py @@ -10,13 +10,12 @@ import inspect import json import uuid -from collections.abc import MutableMapping from dataclasses import dataclass, field from datetime import datetime, timezone from importlib import import_module from typing import TYPE_CHECKING, Any, cast -from agent_framework import AgentThread +from agent_framework import AgentSession from ._constants import REQUEST_RESPONSE_FORMAT_TEXT @@ -274,65 +273,57 @@ def parse(session_id_string: str, agent_name: str | None = None) -> AgentSession raise ValueError(f"Invalid agent session ID format: {session_id_string}") -class DurableAgentThread(AgentThread): - """Durable agent thread that tracks the owning :class:`AgentSessionId`.""" +class DurableAgentSession(AgentSession): + """Durable agent session that tracks the owning :class:`AgentSessionId`.""" _SERIALIZED_SESSION_ID_KEY = "durable_session_id" def __init__( self, *, - session_id: AgentSessionId | None = None, + durable_session_id: AgentSessionId | None = None, + session_id: str | None = None, + service_session_id: str | None = None, **kwargs: Any, ) -> 
None: - super().__init__(**kwargs) - self._session_id: AgentSessionId | None = session_id + super().__init__(session_id=session_id, service_session_id=service_session_id, **kwargs) + self._session_id_value: AgentSessionId | None = durable_session_id @property - def session_id(self) -> AgentSessionId | None: - return self._session_id + def durable_session_id(self) -> AgentSessionId | None: + return self._session_id_value - @session_id.setter - def session_id(self, value: AgentSessionId | None) -> None: - self._session_id = value + @durable_session_id.setter + def durable_session_id(self, value: AgentSessionId | None) -> None: + self._session_id_value = value @classmethod def from_session_id( cls, session_id: AgentSessionId, **kwargs: Any, - ) -> DurableAgentThread: - return cls(session_id=session_id, **kwargs) + ) -> DurableAgentSession: + return cls(durable_session_id=session_id, **kwargs) - async def serialize(self, **kwargs: Any) -> dict[str, Any]: - state = await super().serialize(**kwargs) - if self._session_id is not None: - state[self._SERIALIZED_SESSION_ID_KEY] = str(self._session_id) + def to_dict(self) -> dict[str, Any]: + state = super().to_dict() + if self._session_id_value is not None: + state[self._SERIALIZED_SESSION_ID_KEY] = str(self._session_id_value) return state @classmethod - async def deserialize( - cls, - serialized_thread_state: MutableMapping[str, Any], - *, - message_store: Any = None, - **kwargs: Any, - ) -> DurableAgentThread: - state_payload = dict(serialized_thread_state) + def from_dict(cls, data: dict[str, Any]) -> DurableAgentSession: + state_payload = dict(data) session_id_value = state_payload.pop(cls._SERIALIZED_SESSION_ID_KEY, None) - thread = await super().deserialize( - state_payload, - message_store=message_store, - **kwargs, + session = super().from_dict(state_payload) + # We need to create a DurableAgentSession from the base AgentSession + durable_session = cls( + session_id=session.session_id, + 
service_session_id=session.service_session_id, ) - if not isinstance(thread, DurableAgentThread): - raise TypeError("Deserialized thread is not a DurableAgentThread instance") - - if session_id_value is None: - return thread - - if not isinstance(session_id_value, str): - raise ValueError("durable_session_id must be a string when present in serialized state") - - thread.session_id = AgentSessionId.parse(session_id_value) - return thread + durable_session.state.update(session.state) + if session_id_value is not None: + if not isinstance(session_id_value, str): + raise ValueError("durable_session_id must be a string when present in serialized state") + durable_session._session_id_value = AgentSessionId.parse(session_id_value) + return durable_session diff --git a/python/packages/durabletask/agent_framework_durabletask/_shim.py b/python/packages/durabletask/agent_framework_durabletask/_shim.py index 19ea8a496f..8ad40e34b6 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_shim.py +++ b/python/packages/durabletask/agent_framework_durabletask/_shim.py @@ -12,10 +12,10 @@ from abc import ABC, abstractmethod from typing import Any, Generic, Literal, TypeVar -from agent_framework import AgentThread, Message, SupportsAgentRun +from agent_framework import AgentSession, Message, SupportsAgentRun from ._executors import DurableAgentExecutor -from ._models import DurableAgentThread +from ._models import DurableAgentSession # TypeVar for the task type returned by executors # Covariant because TaskT only appears in return positions (output) @@ -89,7 +89,7 @@ def run( # type: ignore[override] messages: str | Message | list[str] | list[Message] | None = None, *, stream: Literal[False] = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, options: dict[str, Any] | None = None, ) -> TaskT: """Execute the agent via the injected provider. 
@@ -98,7 +98,7 @@ def run( # type: ignore[override] messages: The message(s) to send to the agent stream: Whether to use streaming for the response (must be False) DurableAgents do not support streaming mode. - thread: Optional agent thread for conversation context + session: Optional agent session for conversation context options: Optional options dictionary. Supported keys include ``response_format``, ``enable_tool_calls``, and ``wait_for_response``. Additional keys are forwarded to the agent execution. @@ -129,12 +129,12 @@ def run( # type: ignore[override] return self._executor.run_durable_agent( agent_name=self.name, run_request=run_request, - thread=thread, + session=session, ) - def get_new_thread(self, **kwargs: Any) -> DurableAgentThread: - """Create a new agent thread via the provider.""" - return self._executor.get_new_thread(self.name, **kwargs) + def create_session(self, **kwargs: Any) -> DurableAgentSession: + """Create a new agent session via the provider.""" + return self._executor.get_new_session(self.name, **kwargs) def _normalize_messages(self, messages: str | Message | list[str] | list[Message] | None) -> str: """Convert supported message inputs to a single string. 
diff --git a/python/packages/durabletask/tests/integration_tests/test_01_dt_single_agent.py b/python/packages/durabletask/tests/integration_tests/test_01_dt_single_agent.py index b87e078345..43795f9ef1 100644 --- a/python/packages/durabletask/tests/integration_tests/test_01_dt_single_agent.py +++ b/python/packages/durabletask/tests/integration_tests/test_01_dt_single_agent.py @@ -39,9 +39,9 @@ def test_agent_registration(self) -> None: def test_single_interaction(self): """Test a single interaction with the agent.""" agent = self.agent_client.get_agent("Joker") - thread = agent.get_new_thread() + session = agent.create_session() - response = agent.run("Tell me a short joke about programming.", thread=thread) + response = agent.run("Tell me a short joke about programming.", session=session) assert response is not None assert response.text is not None @@ -50,33 +50,33 @@ def test_single_interaction(self): def test_conversation_continuity(self): """Test that conversation context is maintained across turns.""" agent = self.agent_client.get_agent("Joker") - thread = agent.get_new_thread() + session = agent.create_session() # First turn: Ask for a joke about a specific topic - response1 = agent.run("Tell me a joke about cats.", thread=thread) + response1 = agent.run("Tell me a joke about cats.", session=session) assert response1 is not None assert len(response1.text) > 0 # Second turn: Ask a follow-up that requires context - response2 = agent.run("Can you make it funnier?", thread=thread) + response2 = agent.run("Can you make it funnier?", session=session) assert response2 is not None assert len(response2.text) > 0 # The agent should understand "it" refers to the previous joke - def test_multiple_threads(self): - """Test that different threads maintain separate contexts.""" + def test_multiple_sessions(self): + """Test that different sessions maintain separate contexts.""" agent = self.agent_client.get_agent("Joker") - # Create two separate threads - thread1 = 
agent.get_new_thread() - thread2 = agent.get_new_thread() + # Create two separate sessions + session1 = agent.create_session() + session2 = agent.create_session() - assert thread1.session_id != thread2.session_id + assert session1.durable_session_id != session2.durable_session_id - # Send different messages to each thread - response1 = agent.run("Tell me a joke about dogs.", thread=thread1) - response2 = agent.run("Tell me a joke about birds.", thread=thread2) + # Send different messages to each session + response1 = agent.run("Tell me a joke about dogs.", session=session1) + response2 = agent.run("Tell me a joke about birds.", session=session2) assert response1 is not None assert response2 is not None diff --git a/python/packages/durabletask/tests/integration_tests/test_02_dt_multi_agent.py b/python/packages/durabletask/tests/integration_tests/test_02_dt_multi_agent.py index 02bcd3029a..9d7d8588ac 100644 --- a/python/packages/durabletask/tests/integration_tests/test_02_dt_multi_agent.py +++ b/python/packages/durabletask/tests/integration_tests/test_02_dt_multi_agent.py @@ -47,9 +47,9 @@ def test_multiple_agents_registered(self) -> None: def test_weather_agent_with_tool(self): """Test weather agent with weather tool execution.""" agent = self.agent_client.get_agent(WEATHER_AGENT_NAME) - thread = agent.get_new_thread() + session = agent.create_session() - response = agent.run("What's the weather in Seattle?", thread=thread) + response = agent.run("What's the weather in Seattle?", session=session) assert response is not None assert response.text is not None @@ -66,9 +66,9 @@ def test_weather_agent_with_tool(self): def test_math_agent_with_tool(self): """Test math agent with calculation tool execution.""" agent = self.agent_client.get_agent(MATH_AGENT_NAME) - thread = agent.get_new_thread() + session = agent.create_session() - response = agent.run("Calculate a 20% tip on a $50 bill.", thread=thread) + response = agent.run("Calculate a 20% tip on a $50 bill.", 
session=session) assert response is not None assert response.text is not None @@ -85,11 +85,11 @@ def test_math_agent_with_tool(self): def test_multiple_calls_to_same_agent(self): """Test multiple sequential calls to the same agent.""" agent = self.agent_client.get_agent(WEATHER_AGENT_NAME) - thread = agent.get_new_thread() + session = agent.create_session() # Multiple weather queries - response1 = agent.run("What's the weather in Chicago?", thread=thread) - response2 = agent.run("And what about Los Angeles?", thread=thread) + response1 = agent.run("What's the weather in Chicago?", session=session) + response2 = agent.run("And what about Los Angeles?", session=session) assert response1 is not None assert response2 is not None diff --git a/python/packages/durabletask/tests/integration_tests/test_03_dt_single_agent_streaming.py b/python/packages/durabletask/tests/integration_tests/test_03_dt_single_agent_streaming.py index a05c81b2f8..41e8bf15bb 100644 --- a/python/packages/durabletask/tests/integration_tests/test_03_dt_single_agent_streaming.py +++ b/python/packages/durabletask/tests/integration_tests/test_03_dt_single_agent_streaming.py @@ -70,7 +70,7 @@ async def _get_stream_handler(self) -> RedisStreamResponseHandler: # type: igno async def _stream_from_redis( self, - thread_id: str, + session_key: str, cursor: str | None = None, timeout: float = 30.0, ) -> tuple[str, bool, str]: @@ -78,7 +78,7 @@ async def _stream_from_redis( Stream responses from Redis using the sample's RedisStreamResponseHandler. 
Args: - thread_id: The conversation/thread ID to stream from + session_key: The conversation/thread ID to stream from cursor: Optional cursor to resume from timeout: Maximum time to wait for stream completion @@ -92,7 +92,7 @@ async def _stream_from_redis( async with await self._get_stream_handler() as stream_handler: # type: ignore[reportUnknownMemberType] try: - async for chunk in stream_handler.read_stream(thread_id, cursor): # type: ignore[reportUnknownMemberType] + async for chunk in stream_handler.read_stream(session_key, cursor): # type: ignore[reportUnknownMemberType] if time.time() - start_time > timeout: break @@ -124,15 +124,15 @@ def test_agent_run_and_stream(self) -> None: assert travel_planner is not None assert travel_planner.name == "TravelPlanner" - # Create a new thread - thread = travel_planner.get_new_thread() - assert thread.session_id is not None - assert thread.session_id.key is not None - thread_id = str(thread.session_id.key) + # Create a new session + session = travel_planner.create_session() + assert session.durable_session_id is not None + assert session.durable_session_id.key is not None + session_key = str(session.durable_session_id.key) # Start agent run with wait_for_response=False for non-blocking execution travel_planner.run( - "Plan a 1-day trip to Seattle in 1 sentence", thread=thread, options={"wait_for_response": False} + "Plan a 1-day trip to Seattle in 1 sentence", session=session, options={"wait_for_response": False} ) # Poll Redis stream with retries to handle race conditions @@ -146,7 +146,7 @@ def test_agent_run_and_stream(self) -> None: while retry_count < max_retries and not is_complete: text, is_complete, last_cursor = asyncio.run( - self._stream_from_redis(thread_id, cursor=cursor, timeout=10.0) + self._stream_from_redis(session_key, cursor=cursor, timeout=10.0) ) accumulated_text += text cursor = last_cursor # Resume from last position on next read @@ -166,7 +166,7 @@ def test_agent_run_and_stream(self) -> None: # 
Verify we got content assert len(accumulated_text) > 0, ( - f"Expected text content but got empty string for thread_id: {thread_id} after {retry_count} retries" + f"Expected text content but got empty string for session_key: {session_key} after {retry_count} retries" ) assert "seattle" in accumulated_text.lower(), f"Expected 'seattle' in response but got: {accumulated_text}" assert is_complete, "Expected stream to be complete" @@ -175,13 +175,13 @@ def test_stream_with_cursor_resumption(self) -> None: """Test streaming with cursor-based resumption.""" # Get the TravelPlanner agent travel_planner = self.agent_client.get_agent("TravelPlanner") - thread = travel_planner.get_new_thread() - assert thread.session_id is not None - assert thread.session_id.key is not None - thread_id = str(thread.session_id.key) + session = travel_planner.create_session() + assert session.durable_session_id is not None + assert session.durable_session_id.key is not None + session_key = str(session.durable_session_id.key) # Start agent run - travel_planner.run("What's the weather like?", thread=thread, options={"wait_for_response": False}) + travel_planner.run("What's the weather like?", session=session, options={"wait_for_response": False}) # Wait for agent to start writing time.sleep(3) @@ -194,7 +194,7 @@ async def get_partial_stream() -> tuple[str, str]: chunk_count = 0 # Read just first 2 chunks - async for chunk in stream_handler.read_stream(thread_id): # type: ignore[reportUnknownMemberType] + async for chunk in stream_handler.read_stream(session_key): # type: ignore[reportUnknownMemberType] last_entry_id = chunk.entry_id # type: ignore[reportUnknownMemberType] if chunk.text: # type: ignore[reportUnknownMemberType] accumulated_text += chunk.text # type: ignore[reportUnknownMemberType] @@ -207,7 +207,7 @@ async def get_partial_stream() -> tuple[str, str]: partial_text, cursor = asyncio.run(get_partial_stream()) # Resume from cursor - remaining_text, _, _ = 
asyncio.run(self._stream_from_redis(thread_id, cursor=cursor)) + remaining_text, _, _ = asyncio.run(self._stream_from_redis(session_key, cursor=cursor)) # Verify we got some initial content assert len(partial_text) > 0 diff --git a/python/packages/durabletask/tests/test_agent_session_id.py b/python/packages/durabletask/tests/test_agent_session_id.py index 5481e0109d..571212f145 100644 --- a/python/packages/durabletask/tests/test_agent_session_id.py +++ b/python/packages/durabletask/tests/test_agent_session_id.py @@ -1,11 +1,11 @@ # Copyright (c) Microsoft. All rights reserved. -"""Unit tests for AgentSessionId and DurableAgentThread.""" +"""Unit tests for AgentSessionId and DurableAgentSession.""" import pytest -from agent_framework import AgentThread +from agent_framework import AgentSession -from agent_framework_durabletask._models import AgentSessionId, DurableAgentThread +from agent_framework_durabletask._models import AgentSessionId, DurableAgentSession class TestAgentSessionId: @@ -121,154 +121,162 @@ def test_parse_plain_string_without_agent_name_raises(self) -> None: assert "Invalid agent session ID format" in str(exc_info.value) -class TestDurableAgentThread: - """Test suite for DurableAgentThread.""" +class TestDurableAgentSession: + """Test suite for DurableAgentSession.""" - def test_init_with_session_id(self) -> None: - """Test DurableAgentThread initialization with session ID.""" + def test_init_with_durable_session_id(self) -> None: + """Test DurableAgentSession initialization with durable session ID.""" session_id = AgentSessionId(name="TestAgent", key="test-key") - thread = DurableAgentThread(session_id=session_id) + session = DurableAgentSession(durable_session_id=session_id) - assert thread.session_id is not None - assert thread.session_id == session_id + assert session.durable_session_id is not None + assert session.durable_session_id == session_id - def test_init_without_session_id(self) -> None: - """Test DurableAgentThread initialization 
without session ID.""" - thread = DurableAgentThread() + def test_init_without_durable_session_id(self) -> None: + """Test DurableAgentSession initialization without durable session ID.""" + session = DurableAgentSession() - assert thread.session_id is None + assert session.durable_session_id is None - def test_session_id_setter(self) -> None: - """Test setting a session ID to an existing thread.""" - thread = DurableAgentThread() - assert thread.session_id is None + def test_durable_session_id_setter(self) -> None: + """Test setting a durable session ID to an existing session.""" + session = DurableAgentSession() + assert session.durable_session_id is None session_id = AgentSessionId(name="TestAgent", key="test-key") - thread.session_id = session_id + session.durable_session_id = session_id - assert thread.session_id is not None - assert thread.session_id == session_id - assert thread.session_id.name == "TestAgent" + assert session.durable_session_id is not None + assert session.durable_session_id == session_id + assert session.durable_session_id.name == "TestAgent" def test_from_session_id(self) -> None: - """Test creating DurableAgentThread from session ID.""" + """Test creating DurableAgentSession from session ID.""" session_id = AgentSessionId(name="TestAgent", key="test-key") - thread = DurableAgentThread.from_session_id(session_id) + session = DurableAgentSession.from_session_id(session_id) - assert isinstance(thread, DurableAgentThread) - assert thread.session_id is not None - assert thread.session_id == session_id - assert thread.session_id.name == "TestAgent" - assert thread.session_id.key == "test-key" + assert isinstance(session, DurableAgentSession) + assert session.durable_session_id is not None + assert session.durable_session_id == session_id + assert session.durable_session_id.name == "TestAgent" + assert session.durable_session_id.key == "test-key" - def test_from_session_id_with_service_thread_id(self) -> None: - """Test creating 
DurableAgentThread with service thread ID.""" + def test_from_session_id_with_service_session_id(self) -> None: + """Test creating DurableAgentSession with service session ID.""" session_id = AgentSessionId(name="TestAgent", key="test-key") - thread = DurableAgentThread.from_session_id(session_id, service_thread_id="service-123") + session = DurableAgentSession.from_session_id(session_id, service_session_id="service-123") - assert thread.session_id is not None - assert thread.session_id == session_id - assert thread.service_thread_id == "service-123" + assert session.durable_session_id is not None + assert session.durable_session_id == session_id + assert session.service_session_id == "service-123" - async def test_serialize_with_session_id(self) -> None: - """Test serialization includes session ID.""" + def test_to_dict_with_durable_session_id(self) -> None: + """Test serialization includes durable session ID.""" session_id = AgentSessionId(name="TestAgent", key="test-key") - thread = DurableAgentThread(session_id=session_id) + session = DurableAgentSession(durable_session_id=session_id) - serialized = await thread.serialize() + serialized = session.to_dict() assert isinstance(serialized, dict) assert "durable_session_id" in serialized assert serialized["durable_session_id"] == "@TestAgent@test-key" - async def test_serialize_without_session_id(self) -> None: - """Test serialization without session ID.""" - thread = DurableAgentThread() + def test_to_dict_without_durable_session_id(self) -> None: + """Test serialization without durable session ID.""" + session = DurableAgentSession() - serialized = await thread.serialize() + serialized = session.to_dict() assert isinstance(serialized, dict) assert "durable_session_id" not in serialized - async def test_deserialize_with_session_id(self) -> None: - """Test deserialization restores session ID.""" + def test_from_dict_with_durable_session_id(self) -> None: + """Test deserialization restores durable session ID.""" 
serialized = { - "service_thread_id": "thread-123", + "type": "session", + "session_id": "session-123", + "service_session_id": "service-123", + "state": {}, "durable_session_id": "@TestAgent@test-key", } - thread = await DurableAgentThread.deserialize(serialized) + session = DurableAgentSession.from_dict(serialized) - assert isinstance(thread, DurableAgentThread) - assert thread.session_id is not None - assert thread.session_id.name == "TestAgent" - assert thread.session_id.key == "test-key" - assert thread.service_thread_id == "thread-123" + assert isinstance(session, DurableAgentSession) + assert session.durable_session_id is not None + assert session.durable_session_id.name == "TestAgent" + assert session.durable_session_id.key == "test-key" + assert session.service_session_id == "service-123" - async def test_deserialize_without_session_id(self) -> None: - """Test deserialization without session ID.""" + def test_from_dict_without_durable_session_id(self) -> None: + """Test deserialization without durable session ID.""" serialized = { - "service_thread_id": "thread-456", + "type": "session", + "session_id": "session-456", + "service_session_id": "service-456", + "state": {}, } - thread = await DurableAgentThread.deserialize(serialized) + session = DurableAgentSession.from_dict(serialized) - assert isinstance(thread, DurableAgentThread) - assert thread.session_id is None - assert thread.service_thread_id == "thread-456" + assert isinstance(session, DurableAgentSession) + assert session.durable_session_id is None + assert session.session_id == "session-456" - async def test_round_trip_serialization(self) -> None: - """Test round-trip serialization preserves session ID.""" + def test_round_trip_serialization(self) -> None: + """Test round-trip serialization preserves durable session ID.""" session_id = AgentSessionId(name="TestAgent", key="test-key-789") - original = DurableAgentThread(session_id=session_id) + original = 
DurableAgentSession(durable_session_id=session_id) - serialized = await original.serialize() - restored = await DurableAgentThread.deserialize(serialized) + serialized = original.to_dict() + restored = DurableAgentSession.from_dict(serialized) - assert isinstance(restored, DurableAgentThread) - assert restored.session_id is not None - assert restored.session_id.name == session_id.name - assert restored.session_id.key == session_id.key + assert isinstance(restored, DurableAgentSession) + assert restored.durable_session_id is not None + assert restored.durable_session_id.name == session_id.name + assert restored.durable_session_id.key == session_id.key - async def test_deserialize_invalid_session_id_type(self) -> None: - """Test deserialization with invalid session ID type raises error.""" + def test_from_dict_invalid_durable_session_id_type(self) -> None: + """Test deserialization with invalid durable session ID type raises error.""" serialized = { - "service_thread_id": "thread-123", + "type": "session", + "session_id": "session-123", + "state": {}, "durable_session_id": 12345, # Invalid type } with pytest.raises(ValueError, match="durable_session_id must be a string"): - await DurableAgentThread.deserialize(serialized) + DurableAgentSession.from_dict(serialized) -class TestAgentThreadCompatibility: - """Test suite for compatibility between AgentThread and DurableAgentThread.""" +class TestAgentSessionCompatibility: + """Test suite for compatibility between AgentSession and DurableAgentSession.""" - async def test_agent_thread_serialize(self) -> None: - """Test that base AgentThread can be serialized.""" - thread = AgentThread() + def test_agent_session_to_dict(self) -> None: + """Test that base AgentSession can be serialized.""" + session = AgentSession() - serialized = await thread.serialize() + serialized = session.to_dict() assert isinstance(serialized, dict) - assert "service_thread_id" in serialized + assert "session_id" in serialized - async def 
test_agent_thread_deserialize(self) -> None: - """Test that base AgentThread can be deserialized.""" - thread = AgentThread() - serialized = await thread.serialize() + def test_agent_session_from_dict(self) -> None: + """Test that base AgentSession can be deserialized.""" + session = AgentSession() + serialized = session.to_dict() - restored = await AgentThread.deserialize(serialized) + restored = AgentSession.from_dict(serialized) - assert isinstance(restored, AgentThread) - assert restored.service_thread_id == thread.service_thread_id + assert isinstance(restored, AgentSession) + assert restored.session_id == session.session_id - async def test_durable_thread_is_agent_thread(self) -> None: - """Test that DurableAgentThread is an AgentThread.""" - thread = DurableAgentThread() + def test_durable_session_is_agent_session(self) -> None: + """Test that DurableAgentSession is an AgentSession.""" + session = DurableAgentSession() - assert isinstance(thread, AgentThread) - assert isinstance(thread, DurableAgentThread) + assert isinstance(session, AgentSession) + assert isinstance(session, DurableAgentSession) class TestModelIntegration: @@ -281,19 +289,19 @@ def test_session_id_string_format(self) -> None: assert session_id_str.startswith("@AgentEntity@") - async def test_thread_with_session_preserves_on_serialization(self) -> None: - """Test that thread with session ID preserves it through serialization.""" + def test_session_with_durable_id_preserves_on_serialization(self) -> None: + """Test that session with durable session ID preserves it through serialization.""" session_id = AgentSessionId(name="TestAgent", key="preserved-key") - thread = DurableAgentThread.from_session_id(session_id) + session = DurableAgentSession.from_session_id(session_id) # Serialize and deserialize - serialized = await thread.serialize() - restored = await DurableAgentThread.deserialize(serialized) + serialized = session.to_dict() + restored = DurableAgentSession.from_dict(serialized) - # 
Session ID should be preserved - assert restored.session_id is not None - assert restored.session_id.name == "TestAgent" - assert restored.session_id.key == "preserved-key" + # Durable session ID should be preserved + assert restored.durable_session_id is not None + assert restored.durable_session_id.name == "TestAgent" + assert restored.durable_session_id.key == "preserved-key" if __name__ == "__main__": diff --git a/python/packages/durabletask/tests/test_client.py b/python/packages/durabletask/tests/test_client.py index 7486352d17..0acdfb2f9c 100644 --- a/python/packages/durabletask/tests/test_client.py +++ b/python/packages/durabletask/tests/test_client.py @@ -11,7 +11,7 @@ import pytest from agent_framework import SupportsAgentRun -from agent_framework_durabletask import DurableAgentThread, DurableAIAgentClient +from agent_framework_durabletask import DurableAgentSession, DurableAIAgentClient from agent_framework_durabletask._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS from agent_framework_durabletask._shim import DurableAIAgent @@ -80,22 +80,22 @@ def test_client_agent_has_working_run_method(self, agent_client: DurableAIAgentC assert hasattr(agent, "run") assert callable(agent.run) - def test_client_agent_can_create_threads(self, agent_client: DurableAIAgentClient) -> None: - """Verify agent from client can create DurableAgentThread instances.""" + def test_client_agent_can_create_sessions(self, agent_client: DurableAIAgentClient) -> None: + """Verify agent from client can create DurableAgentSession instances.""" agent = agent_client.get_agent("assistant") - thread = agent.get_new_thread() + session = agent.create_session() - assert isinstance(thread, DurableAgentThread) + assert isinstance(session, DurableAgentSession) - def test_client_agent_thread_with_parameters(self, agent_client: DurableAIAgentClient) -> None: - """Verify agent can create threads with custom parameters.""" + def test_client_agent_session_with_parameters(self, 
agent_client: DurableAIAgentClient) -> None: + """Verify agent can create sessions with custom parameters.""" agent = agent_client.get_agent("assistant") - thread = agent.get_new_thread(service_thread_id="client-session-123") + session = agent.create_session(service_session_id="client-session-123") - assert isinstance(thread, DurableAgentThread) - assert thread.service_thread_id == "client-session-123" + assert isinstance(session, DurableAgentSession) + assert session.service_session_id == "client-session-123" class TestDurableAIAgentClientPollingConfiguration: diff --git a/python/packages/durabletask/tests/test_executors.py b/python/packages/durabletask/tests/test_executors.py index 802007541f..36aa8d46b2 100644 --- a/python/packages/durabletask/tests/test_executors.py +++ b/python/packages/durabletask/tests/test_executors.py @@ -16,7 +16,7 @@ from durabletask.task import Task from pydantic import BaseModel -from agent_framework_durabletask import DurableAgentThread +from agent_framework_durabletask import DurableAgentSession from agent_framework_durabletask._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS from agent_framework_durabletask._executors import ( ClientAgentExecutor, @@ -106,42 +106,42 @@ def _configure(exception: Exception) -> Mock: return _configure -class TestExecutorThreadCreation: - """Test that executors properly create DurableAgentThread with parameters.""" +class TestExecutorSessionCreation: + """Test that executors properly create DurableAgentSession with parameters.""" - def test_client_executor_creates_durable_thread(self, mock_client: Mock) -> None: - """Verify ClientAgentExecutor creates DurableAgentThread instances.""" + def test_client_executor_creates_durable_session(self, mock_client: Mock) -> None: + """Verify ClientAgentExecutor creates DurableAgentSession instances.""" executor = ClientAgentExecutor(mock_client) - thread = executor.get_new_thread("test_agent") + session = 
executor.get_new_session("test_agent") - assert isinstance(thread, DurableAgentThread) + assert isinstance(session, DurableAgentSession) - def test_client_executor_forwards_kwargs_to_thread(self, mock_client: Mock) -> None: - """Verify ClientAgentExecutor forwards kwargs to DurableAgentThread creation.""" + def test_client_executor_forwards_kwargs_to_session(self, mock_client: Mock) -> None: + """Verify ClientAgentExecutor forwards kwargs to DurableAgentSession creation.""" executor = ClientAgentExecutor(mock_client) - thread = executor.get_new_thread("test_agent", service_thread_id="client-123") + session = executor.get_new_session("test_agent", service_session_id="client-123") - assert isinstance(thread, DurableAgentThread) - assert thread.service_thread_id == "client-123" + assert isinstance(session, DurableAgentSession) + assert session.service_session_id == "client-123" - def test_orchestration_executor_creates_durable_thread( + def test_orchestration_executor_creates_durable_session( self, orchestration_executor: OrchestrationAgentExecutor ) -> None: - """Verify OrchestrationAgentExecutor creates DurableAgentThread instances.""" - thread = orchestration_executor.get_new_thread("test_agent") + """Verify OrchestrationAgentExecutor creates DurableAgentSession instances.""" + session = orchestration_executor.get_new_session("test_agent") - assert isinstance(thread, DurableAgentThread) + assert isinstance(session, DurableAgentSession) - def test_orchestration_executor_forwards_kwargs_to_thread( + def test_orchestration_executor_forwards_kwargs_to_session( self, orchestration_executor: OrchestrationAgentExecutor ) -> None: - """Verify OrchestrationAgentExecutor forwards kwargs to DurableAgentThread creation.""" - thread = orchestration_executor.get_new_thread("test_agent", service_thread_id="orch-456") + """Verify OrchestrationAgentExecutor forwards kwargs to DurableAgentSession creation.""" + session = orchestration_executor.get_new_session("test_agent", 
service_session_id="orch-456") - assert isinstance(thread, DurableAgentThread) - assert thread.service_thread_id == "orch-456" + assert isinstance(session, DurableAgentSession) + assert session.service_session_id == "orch-456" class TestClientAgentExecutorRun: @@ -353,18 +353,18 @@ def test_orchestration_executor_calls_entity_with_correct_parameters( # Verify request dict assert request_dict_arg == sample_run_request.to_dict() - def test_orchestration_executor_uses_thread_session_id( + def test_orchestration_executor_uses_session_durable_id( self, mock_orchestration_context: Mock, orchestration_executor: OrchestrationAgentExecutor, sample_run_request: RunRequest, ) -> None: - """Verify executor uses thread's session ID when provided.""" - # Create thread with specific session ID + """Verify executor uses session's durable session ID when provided.""" + # Create session with specific durable session ID session_id = AgentSessionId(name="test_agent", key="specific-key-123") - thread = DurableAgentThread.from_session_id(session_id) + session = DurableAgentSession.from_session_id(session_id) - result = orchestration_executor.run_durable_agent("test_agent", sample_run_request, thread=thread) + result = orchestration_executor.run_durable_agent("test_agent", sample_run_request, session=session) # Verify call_entity was called with the specific key call_args = mock_orchestration_context.call_entity.call_args diff --git a/python/packages/durabletask/tests/test_orchestration_context.py b/python/packages/durabletask/tests/test_orchestration_context.py index 073f0e1642..033c274c88 100644 --- a/python/packages/durabletask/tests/test_orchestration_context.py +++ b/python/packages/durabletask/tests/test_orchestration_context.py @@ -11,7 +11,7 @@ import pytest from agent_framework import SupportsAgentRun -from agent_framework_durabletask import DurableAgentThread +from agent_framework_durabletask import DurableAgentSession from agent_framework_durabletask._orchestration_context 
import DurableAIAgentOrchestrationContext from agent_framework_durabletask._shim import DurableAIAgent @@ -74,24 +74,24 @@ def test_orchestration_agent_has_working_run_method( assert hasattr(agent, "run") assert callable(agent.run) - def test_orchestration_agent_can_create_threads(self, agent_context: DurableAIAgentOrchestrationContext) -> None: - """Verify agent from context can create DurableAgentThread instances.""" + def test_orchestration_agent_can_create_sessions(self, agent_context: DurableAIAgentOrchestrationContext) -> None: + """Verify agent from context can create DurableAgentSession instances.""" agent = agent_context.get_agent("assistant") - thread = agent.get_new_thread() + session = agent.create_session() - assert isinstance(thread, DurableAgentThread) + assert isinstance(session, DurableAgentSession) - def test_orchestration_agent_thread_with_parameters( + def test_orchestration_agent_session_with_parameters( self, agent_context: DurableAIAgentOrchestrationContext ) -> None: - """Verify agent can create threads with custom parameters.""" + """Verify agent can create sessions with custom parameters.""" agent = agent_context.get_agent("assistant") - thread = agent.get_new_thread(service_thread_id="orch-session-456") + session = agent.create_session(service_session_id="orch-session-456") - assert isinstance(thread, DurableAgentThread) - assert thread.service_thread_id == "orch-session-456" + assert isinstance(session, DurableAgentSession) + assert session.service_session_id == "orch-session-456" if __name__ == "__main__": diff --git a/python/packages/durabletask/tests/test_shim.py b/python/packages/durabletask/tests/test_shim.py index 9f2fefc406..423f587871 100644 --- a/python/packages/durabletask/tests/test_shim.py +++ b/python/packages/durabletask/tests/test_shim.py @@ -13,7 +13,7 @@ from agent_framework import Message, SupportsAgentRun from pydantic import BaseModel -from agent_framework_durabletask import DurableAgentThread +from 
agent_framework_durabletask import DurableAgentSession from agent_framework_durabletask._executors import DurableAgentExecutor from agent_framework_durabletask._models import RunRequest from agent_framework_durabletask._shim import DurableAgentProvider, DurableAIAgent @@ -30,7 +30,7 @@ def mock_executor() -> Mock: """Create a mock executor for testing.""" mock = Mock(spec=DurableAgentExecutor) mock.run_durable_agent = Mock(return_value=None) - mock.get_new_thread = Mock(return_value=DurableAgentThread()) + mock.get_new_session = Mock(return_value=DurableAgentSession()) # Mock get_run_request to create actual RunRequest objects def create_run_request( @@ -124,14 +124,14 @@ def test_run_handles_empty_list(self, test_agent: DurableAIAgent[Any], mock_exec class TestDurableAIAgentParameterFlow: """Test that parameters flow correctly through the shim to executor.""" - def test_run_forwards_thread_parameter(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: - """Verify run forwards thread parameter to executor.""" - thread = DurableAgentThread(service_thread_id="test-thread") - test_agent.run("message", thread=thread) + def test_run_forwards_session_parameter(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: + """Verify run forwards session parameter to executor.""" + session = DurableAgentSession(service_session_id="test-session") + test_agent.run("message", session=session) mock_executor.run_durable_agent.assert_called_once() _, kwargs = mock_executor.run_durable_agent.call_args - assert kwargs["thread"] == thread + assert kwargs["session"] == session def test_run_forwards_response_format(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: """Verify run forwards response_format parameter to executor.""" @@ -171,29 +171,29 @@ def test_agent_id_can_be_customized(self, mock_executor: Mock) -> None: assert agent.name == "my_agent" -class TestDurableAIAgentThreadManagement: - """Test thread creation and management.""" 
+class TestDurableAIAgentSessionManagement: + """Test session creation and management.""" - def test_get_new_thread_delegates_to_executor(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: - """Verify get_new_thread delegates to executor.""" - mock_thread = DurableAgentThread() - mock_executor.get_new_thread.return_value = mock_thread + def test_create_session_delegates_to_executor(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: + """Verify create_session delegates to executor.""" + mock_session = DurableAgentSession() + mock_executor.get_new_session.return_value = mock_session - thread = test_agent.get_new_thread() + session = test_agent.create_session() - mock_executor.get_new_thread.assert_called_once_with("test_agent") - assert thread == mock_thread + mock_executor.get_new_session.assert_called_once_with("test_agent") + assert session == mock_session - def test_get_new_thread_forwards_kwargs(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: - """Verify get_new_thread forwards kwargs to executor.""" - mock_thread = DurableAgentThread(service_thread_id="thread-123") - mock_executor.get_new_thread.return_value = mock_thread + def test_create_session_forwards_kwargs(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: + """Verify create_session forwards kwargs to executor.""" + mock_session = DurableAgentSession(service_session_id="session-123") + mock_executor.get_new_session.return_value = mock_session - test_agent.get_new_thread(service_thread_id="thread-123") + test_agent.create_session(service_session_id="session-123") - mock_executor.get_new_thread.assert_called_once() - _, kwargs = mock_executor.get_new_thread.call_args - assert kwargs["service_thread_id"] == "thread-123" + mock_executor.get_new_session.assert_called_once() + _, kwargs = mock_executor.get_new_session.call_args + assert kwargs["service_session_id"] == "session-123" class TestDurableAgentProviderInterface: diff --git 
a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index f8e60c9f3e..42de197014 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -13,10 +13,10 @@ AgentMiddlewareTypes, AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, BaseAgent, + BaseContextProvider, Content, - ContextProvider, Message, ResponseStream, normalize_messages, @@ -149,7 +149,7 @@ def __init__( id: str | None = None, name: str | None = None, description: str | None = None, - context_provider: ContextProvider | None = None, + context_providers: Sequence[BaseContextProvider] | None = None, middleware: Sequence[AgentMiddlewareTypes] | None = None, tools: FunctionTool | Callable[..., Any] @@ -171,7 +171,7 @@ def __init__( id: ID of the GitHubCopilotAgent. name: Name of the GitHubCopilotAgent. description: Description of the GitHubCopilotAgent. - context_provider: Context Provider, to be used by the agent. + context_providers: Context Providers, to be used by the agent. middleware: Agent middleware used by the agent. tools: Tools to use for the agent. Can be functions or tool definition dicts. These are converted to Copilot SDK tools internally. @@ -187,7 +187,7 @@ def __init__( id=id, name=name, description=description, - context_provider=context_provider, + context_providers=context_providers, middleware=list(middleware) if middleware else None, ) @@ -280,7 +280,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: Literal[False] = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, options: OptionsT | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse]: ... 
@@ -291,7 +291,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: Literal[True], - thread: AgentThread | None = None, + session: AgentSession | None = None, options: OptionsT | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ... @@ -301,7 +301,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, options: OptionsT | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: @@ -316,7 +316,7 @@ def run( Keyword Args: stream: Whether to stream the response. Defaults to False. - thread: The conversation thread associated with the message(s). + session: The conversation session associated with the message(s). options: Runtime options (model, timeout, etc.). kwargs: Additional keyword arguments. @@ -333,16 +333,16 @@ def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse: return AgentResponse.from_updates(updates) return ResponseStream( - self._stream_updates(messages=messages, thread=thread, options=options, **kwargs), + self._stream_updates(messages=messages, session=session, options=options, **kwargs), finalizer=_finalize, ) - return self._run_impl(messages=messages, thread=thread, options=options, **kwargs) + return self._run_impl(messages=messages, session=session, options=options, **kwargs) async def _run_impl( self, messages: str | Message | Sequence[str | Message] | None = None, *, - thread: AgentThread | None = None, + session: AgentSession | None = None, options: OptionsT | None = None, **kwargs: Any, ) -> AgentResponse: @@ -350,18 +350,18 @@ async def _run_impl( if not self._started: await self.start() - if not thread: - thread = self.get_new_thread() + if not session: + session = self.create_session() opts: dict[str, Any] = dict(options) if options else {} timeout = opts.pop("timeout", None) 
or self._settings["timeout"] or DEFAULT_TIMEOUT_SECONDS - session = await self._get_or_create_session(thread, streaming=False, runtime_options=opts) + copilot_session = await self._get_or_create_session(session, streaming=False, runtime_options=opts) input_messages = normalize_messages(messages) prompt = "\n".join([message.text for message in input_messages]) try: - response_event = await session.send_and_wait({"prompt": prompt}, timeout=timeout) + response_event = await copilot_session.send_and_wait({"prompt": prompt}, timeout=timeout) except Exception as ex: raise ServiceException(f"GitHub Copilot request failed: {ex}") from ex @@ -390,7 +390,7 @@ async def _stream_updates( self, messages: str | Message | Sequence[str | Message] | None = None, *, - thread: AgentThread | None = None, + session: AgentSession | None = None, options: OptionsT | None = None, **kwargs: Any, ) -> AsyncIterable[AgentResponseUpdate]: @@ -400,7 +400,7 @@ async def _stream_updates( messages: The message(s) to send to the agent. Keyword Args: - thread: The conversation thread associated with the message(s). + session: The conversation session associated with the message(s). options: Runtime options (model, timeout, etc.). kwargs: Additional keyword arguments. 
@@ -413,12 +413,12 @@ async def _stream_updates( if not self._started: await self.start() - if not thread: - thread = self.get_new_thread() + if not session: + session = self.create_session() opts: dict[str, Any] = dict(options) if options else {} - session = await self._get_or_create_session(thread, streaming=True, runtime_options=opts) + copilot_session = await self._get_or_create_session(session, streaming=True, runtime_options=opts) input_messages = normalize_messages(messages) prompt = "\n".join([message.text for message in input_messages]) @@ -441,10 +441,10 @@ def event_handler(event: SessionEvent) -> None: error_msg = event.data.message or "Unknown error" queue.put_nowait(ServiceException(f"GitHub Copilot session error: {error_msg}")) - unsubscribe = session.on(event_handler) + unsubscribe = copilot_session.on(event_handler) try: - await session.send({"prompt": prompt}) + await copilot_session.send({"prompt": prompt}) while (item := await queue.get()) is not None: if isinstance(item, Exception): @@ -530,14 +530,14 @@ async def handler(invocation: ToolInvocation) -> ToolResult: async def _get_or_create_session( self, - thread: AgentThread, + agent_session: AgentSession, streaming: bool = False, runtime_options: dict[str, Any] | None = None, ) -> CopilotSession: - """Get an existing session or create a new one for the thread. + """Get an existing session or create a new one for the session. Args: - thread: The conversation thread. + agent_session: The conversation session. streaming: Whether to enable streaming for the session. runtime_options: Runtime options from run that take precedence. @@ -551,11 +551,11 @@ async def _get_or_create_session( raise ServiceException("GitHub Copilot client not initialized. 
Call start() first.") try: - if thread.service_thread_id: - return await self._resume_session(thread.service_thread_id, streaming) + if agent_session.service_session_id: + return await self._resume_session(agent_session.service_session_id, streaming) session = await self._create_session(streaming, runtime_options) - thread.service_thread_id = session.session_id + agent_session.service_session_id = session.session_id return session except Exception as ex: raise ServiceException(f"Failed to create GitHub Copilot session: {ex}") from ex diff --git a/python/packages/mem0/agent_framework_mem0/__init__.py b/python/packages/mem0/agent_framework_mem0/__init__.py index b43f1dba2c..cb6f75f8a5 100644 --- a/python/packages/mem0/agent_framework_mem0/__init__.py +++ b/python/packages/mem0/agent_framework_mem0/__init__.py @@ -8,8 +8,7 @@ if os.environ.get("MEM0_TELEMETRY") is None: os.environ["MEM0_TELEMETRY"] = "false" -from ._context_provider import _Mem0ContextProvider -from ._provider import Mem0Provider +from ._context_provider import Mem0ContextProvider try: __version__ = importlib.metadata.version(__name__) @@ -17,7 +16,6 @@ __version__ = "0.0.0" # Fallback for development mode __all__ = [ - "Mem0Provider", - "_Mem0ContextProvider", + "Mem0ContextProvider", "__version__", ] diff --git a/python/packages/mem0/agent_framework_mem0/_context_provider.py b/python/packages/mem0/agent_framework_mem0/_context_provider.py index 6a09887b72..c2d10d42cb 100644 --- a/python/packages/mem0/agent_framework_mem0/_context_provider.py +++ b/python/packages/mem0/agent_framework_mem0/_context_provider.py @@ -2,9 +2,8 @@ """New-pattern Mem0 context provider using BaseContextProvider. -This module provides ``_Mem0ContextProvider``, a side-by-side implementation of -:class:`Mem0Provider` built on the new :class:`BaseContextProvider` hooks pattern. -It will be renamed to ``Mem0ContextProvider`` in PR2 when the old class is removed. 
+This module provides ``Mem0ContextProvider``, built on the new +:class:`BaseContextProvider` hooks pattern. """ from __future__ import annotations @@ -35,17 +34,11 @@ class _MemorySearchResponse_v1_1(TypedDict): _MemorySearchResponse_v2 = list[dict[str, Any]] -class _Mem0ContextProvider(BaseContextProvider): +class Mem0ContextProvider(BaseContextProvider): """Mem0 context provider using the new BaseContextProvider hooks pattern. Integrates Mem0 for persistent semantic memory, searching and storing - memories via the Mem0 API. This is the new-pattern equivalent of - :class:`Mem0Provider`. - - Note: - This class uses a temporary ``_`` prefix to coexist with the existing - :class:`Mem0Provider`. It will be renamed to ``Mem0ContextProvider`` - in PR2. + memories via the Mem0 API. """ DEFAULT_CONTEXT_PROMPT = "## Memories\nConsider the following memories when answering user questions:" @@ -190,4 +183,4 @@ def _build_filters(self, *, session_id: str | None = None) -> dict[str, Any]: return filters -__all__ = ["_Mem0ContextProvider"] +__all__ = ["Mem0ContextProvider"] diff --git a/python/packages/mem0/tests/test_mem0_new_context_provider.py b/python/packages/mem0/tests/test_mem0_new_context_provider.py index a56e427e68..96a70c2beb 100644 --- a/python/packages/mem0/tests/test_mem0_new_context_provider.py +++ b/python/packages/mem0/tests/test_mem0_new_context_provider.py @@ -10,7 +10,7 @@ from agent_framework._sessions import AgentSession, SessionContext from agent_framework.exceptions import ServiceInitializationError -from agent_framework_mem0._context_provider import _Mem0ContextProvider +from agent_framework_mem0._context_provider import Mem0ContextProvider @pytest.fixture @@ -30,10 +30,10 @@ def mock_mem0_client() -> AsyncMock: class TestInit: - """Test _Mem0ContextProvider initialization.""" + """Test Mem0ContextProvider initialization.""" def test_init_with_all_params(self, mock_mem0_client: AsyncMock) -> None: - provider = _Mem0ContextProvider( + provider = 
Mem0ContextProvider( source_id="mem0", mem0_client=mock_mem0_client, api_key="key-123", @@ -52,8 +52,8 @@ def test_init_with_all_params(self, mock_mem0_client: AsyncMock) -> None: assert provider._should_close_client is False def test_init_default_context_prompt(self, mock_mem0_client: AsyncMock) -> None: - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - assert provider.context_prompt == _Mem0ContextProvider.DEFAULT_CONTEXT_PROMPT + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + assert provider.context_prompt == Mem0ContextProvider.DEFAULT_CONTEXT_PROMPT def test_init_auto_creates_client_when_none(self) -> None: """When no client is provided, a default AsyncMemoryClient is created and flagged for closing.""" @@ -61,12 +61,12 @@ def test_init_auto_creates_client_when_none(self) -> None: patch("mem0.client.main.AsyncMemoryClient.__init__", return_value=None) as mock_init, patch("mem0.client.main.AsyncMemoryClient._validate_api_key", return_value=None), ): - provider = _Mem0ContextProvider(source_id="mem0", api_key="test-key", user_id="u1") + provider = Mem0ContextProvider(source_id="mem0", api_key="test-key", user_id="u1") mock_init.assert_called_once_with(api_key="test-key") assert provider._should_close_client is True def test_provided_client_not_flagged_for_close(self, mock_mem0_client: AsyncMock) -> None: - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") assert provider._should_close_client is False @@ -82,7 +82,7 @@ async def test_memories_added_to_context(self, mock_mem0_client: AsyncMock) -> N {"memory": "User likes Python"}, {"memory": "User prefers dark mode"}, ] - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + provider = Mem0ContextProvider(source_id="mem0", 
mem0_client=mock_mem0_client, user_id="u1") session = AgentSession(session_id="test-session") ctx = SessionContext(input_messages=[Message(role="user", text="Hello")], session_id="s1") @@ -98,7 +98,7 @@ async def test_memories_added_to_context(self, mock_mem0_client: AsyncMock) -> N async def test_empty_input_skips_search(self, mock_mem0_client: AsyncMock) -> None: """Empty input messages → no search performed.""" - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") session = AgentSession(session_id="test-session") ctx = SessionContext(input_messages=[Message(role="user", text="")], session_id="s1") @@ -110,7 +110,7 @@ async def test_empty_input_skips_search(self, mock_mem0_client: AsyncMock) -> No async def test_empty_search_results_no_messages(self, mock_mem0_client: AsyncMock) -> None: """Empty search results → no messages added.""" mock_mem0_client.search.return_value = [] - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") session = AgentSession(session_id="test-session") ctx = SessionContext(input_messages=[Message(role="user", text="test")], session_id="s1") @@ -120,7 +120,7 @@ async def test_empty_search_results_no_messages(self, mock_mem0_client: AsyncMoc async def test_validates_filters_before_search(self, mock_mem0_client: AsyncMock) -> None: """Raises ServiceInitializationError when no filters.""" - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) session = AgentSession(session_id="test-session") ctx = SessionContext(input_messages=[Message(role="user", text="test")], session_id="s1") @@ -130,7 +130,7 @@ async def test_validates_filters_before_search(self, 
mock_mem0_client: AsyncMock async def test_v1_1_response_format(self, mock_mem0_client: AsyncMock) -> None: """Search response in v1.1 dict format with 'results' key.""" mock_mem0_client.search.return_value = {"results": [{"memory": "remembered fact"}]} - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") session = AgentSession(session_id="test-session") ctx = SessionContext(input_messages=[Message(role="user", text="test")], session_id="s1") @@ -142,7 +142,7 @@ async def test_v1_1_response_format(self, mock_mem0_client: AsyncMock) -> None: async def test_search_query_combines_input_messages(self, mock_mem0_client: AsyncMock) -> None: """Multiple input messages are joined for the search query.""" mock_mem0_client.search.return_value = [] - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") session = AgentSession(session_id="test-session") ctx = SessionContext( input_messages=[ @@ -166,7 +166,7 @@ class TestAfterRun: async def test_stores_input_and_response(self, mock_mem0_client: AsyncMock) -> None: """Stores input+response messages to mem0 via client.add.""" - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") session = AgentSession(session_id="test-session") ctx = SessionContext(input_messages=[Message(role="user", text="question")], session_id="s1") ctx._response = AgentResponse(messages=[Message(role="assistant", text="answer")]) @@ -184,7 +184,7 @@ async def test_stores_input_and_response(self, mock_mem0_client: AsyncMock) -> N async def test_only_stores_user_assistant_system(self, mock_mem0_client: AsyncMock) -> None: """Only stores 
user/assistant/system messages with text.""" - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") session = AgentSession(session_id="test-session") ctx = SessionContext( input_messages=[ @@ -204,7 +204,7 @@ async def test_only_stores_user_assistant_system(self, mock_mem0_client: AsyncMo async def test_skips_empty_messages(self, mock_mem0_client: AsyncMock) -> None: """Skips messages with empty text.""" - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") session = AgentSession(session_id="test-session") ctx = SessionContext( input_messages=[ @@ -221,7 +221,7 @@ async def test_skips_empty_messages(self, mock_mem0_client: AsyncMock) -> None: async def test_uses_session_id_as_run_id(self, mock_mem0_client: AsyncMock) -> None: """Uses session_id as run_id.""" - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") session = AgentSession(session_id="test-session") ctx = SessionContext(input_messages=[Message(role="user", text="hi")], session_id="my-session") ctx._response = AgentResponse(messages=[Message(role="assistant", text="hey")]) @@ -232,7 +232,7 @@ async def test_uses_session_id_as_run_id(self, mock_mem0_client: AsyncMock) -> N async def test_validates_filters(self, mock_mem0_client: AsyncMock) -> None: """Raises ServiceInitializationError when no filters.""" - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) session = AgentSession(session_id="test-session") ctx = SessionContext(input_messages=[Message(role="user", text="hi")], session_id="s1") 
ctx._response = AgentResponse(messages=[Message(role="assistant", text="hey")]) @@ -242,7 +242,7 @@ async def test_validates_filters(self, mock_mem0_client: AsyncMock) -> None: async def test_stores_with_application_id_metadata(self, mock_mem0_client: AsyncMock) -> None: """application_id is passed in metadata.""" - provider = _Mem0ContextProvider( + provider = Mem0ContextProvider( source_id="mem0", mem0_client=mock_mem0_client, user_id="u1", application_id="app1" ) session = AgentSession(session_id="test-session") @@ -261,20 +261,20 @@ class TestValidateFilters: """Test _validate_filters method.""" def test_raises_when_no_filters(self, mock_mem0_client: AsyncMock) -> None: - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) with pytest.raises(ServiceInitializationError, match="At least one of the filters"): provider._validate_filters() def test_passes_with_user_id(self, mock_mem0_client: AsyncMock) -> None: - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") provider._validate_filters() # should not raise def test_passes_with_agent_id(self, mock_mem0_client: AsyncMock) -> None: - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, agent_id="a1") + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, agent_id="a1") provider._validate_filters() def test_passes_with_application_id(self, mock_mem0_client: AsyncMock) -> None: - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, application_id="app1") + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, application_id="app1") provider._validate_filters() @@ -285,11 +285,11 @@ class TestBuildFilters: """Test _build_filters method.""" def test_user_id_only(self, 
mock_mem0_client: AsyncMock) -> None: - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") assert provider._build_filters() == {"user_id": "u1"} def test_all_params(self, mock_mem0_client: AsyncMock) -> None: - provider = _Mem0ContextProvider( + provider = Mem0ContextProvider( source_id="mem0", mem0_client=mock_mem0_client, user_id="u1", @@ -304,19 +304,19 @@ def test_all_params(self, mock_mem0_client: AsyncMock) -> None: } def test_excludes_none_values(self, mock_mem0_client: AsyncMock) -> None: - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") filters = provider._build_filters() assert "agent_id" not in filters assert "run_id" not in filters assert "app_id" not in filters def test_session_id_mapped_to_run_id(self, mock_mem0_client: AsyncMock) -> None: - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") filters = provider._build_filters(session_id="s99") assert filters["run_id"] == "s99" def test_empty_when_no_params(self, mock_mem0_client: AsyncMock) -> None: - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) assert provider._build_filters() == {} @@ -327,26 +327,26 @@ class TestContextManager: """Test __aenter__/__aexit__ delegation.""" async def test_aenter_delegates_to_client(self, mock_mem0_client: AsyncMock) -> None: - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") result = await 
provider.__aenter__() assert result is provider mock_mem0_client.__aenter__.assert_awaited_once() async def test_aexit_closes_auto_created_client(self, mock_mem0_client: AsyncMock) -> None: """Auto-created clients (_should_close_client=True) are closed on exit.""" - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") provider._should_close_client = True await provider.__aexit__(None, None, None) mock_mem0_client.__aexit__.assert_awaited_once() async def test_aexit_does_not_close_provided_client(self, mock_mem0_client: AsyncMock) -> None: """Provided clients (_should_close_client=False) are NOT closed on exit.""" - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") assert provider._should_close_client is False await provider.__aexit__(None, None, None) mock_mem0_client.__aexit__.assert_not_awaited() async def test_async_with_syntax(self, mock_mem0_client: AsyncMock) -> None: - provider = _Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") async with provider as p: assert p is provider diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_group_chat.py b/python/packages/orchestrations/agent_framework_orchestrations/_group_chat.py index 747d6efb10..857b1128b4 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_group_chat.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_group_chat.py @@ -29,10 +29,11 @@ from dataclasses import dataclass from typing import Any, ClassVar, cast -from agent_framework import Agent, AgentThread, Message, SupportsAgentRun +from agent_framework import Agent, AgentSession, Message, 
SupportsAgentRun from agent_framework._workflows._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse from agent_framework._workflows._agent_utils import resolve_agent_id from agent_framework._workflows._checkpoint import CheckpointStorage +from agent_framework._workflows._conversation_state import decode_chat_messages, encode_chat_messages from agent_framework._workflows._executor import Executor from agent_framework._workflows._workflow import Workflow from agent_framework._workflows._workflow_builder import WorkflowBuilder @@ -291,7 +292,7 @@ def __init__( max_rounds: int | None = None, termination_condition: TerminationCondition | None = None, retry_attempts: int | None = None, - thread: AgentThread | None = None, + session: AgentSession | None = None, ) -> None: """Initialize the GroupChatOrchestrator. @@ -302,7 +303,7 @@ def __init__( max_rounds: Optional limit on selection rounds to prevent infinite loops. termination_condition: Optional callable that halts the conversation when it returns True retry_attempts: Optional number of retry attempts for the agent in case of failure. - thread: Optional agent thread to use for the orchestrator agent. + session: Optional agent session to use for the orchestrator agent. 
""" super().__init__( resolve_agent_id(agent), @@ -313,7 +314,7 @@ def __init__( ) self._agent = agent self._retry_attempts = retry_attempts - self._thread = thread or agent.get_new_thread() + self._session = session or agent.create_session() # Cache for messages since last agent invocation # This is different from the full conversation history maintained by the base orchestrator self._cache: list[Message] = [] @@ -471,7 +472,7 @@ async def _invoke_agent_helper(conversation: list[Message]) -> AgentOrchestratio # Run the agent in non-streaming mode for simplicity agent_response = await self._agent.run( messages=conversation, - thread=self._thread, + session=self._session, options={"response_format": AgentOrchestrationOutput}, ) # Parse and validate the structured output @@ -546,9 +547,9 @@ async def _check_agent_terminate_and_yield( async def on_checkpoint_save(self) -> dict[str, Any]: """Capture current orchestrator state for checkpointing.""" state = await super().on_checkpoint_save() - state["cache"] = self._cache - serialized_thread = await self._thread.serialize() - state["thread"] = serialized_thread + state["cache"] = encode_chat_messages(self._cache) + serialized_session = self._session.to_dict() + state["session"] = serialized_session return state @@ -556,10 +557,10 @@ async def on_checkpoint_save(self) -> dict[str, Any]: async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: """Restore executor state from checkpoint.""" await super().on_checkpoint_restore(state) - self._cache = state.get("cache", []) - serialized_thread = state.get("thread") - if serialized_thread: - self._thread = await self._agent.deserialize_thread(serialized_thread) + self._cache = decode_chat_messages(state.get("cache", [])) + serialized_session = state.get("session") + if serialized_session: + self._session = AgentSession.from_dict(serialized_session) # endregion diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py 
b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py index 367855e4c2..ce365724ef 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py @@ -38,7 +38,7 @@ from agent_framework import Agent, SupportsAgentRun from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware -from agent_framework._threads import AgentThread +from agent_framework._sessions import AgentSession from agent_framework._tools import FunctionTool, tool from agent_framework._types import AgentResponse, AgentResponseUpdate, Message from agent_framework._workflows._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse @@ -196,7 +196,7 @@ def __init__( agent: SupportsAgentRun, handoffs: Sequence[HandoffConfiguration], *, - agent_thread: AgentThread | None = None, + agent_thread: AgentSession | None = None, is_start_agent: bool = False, termination_condition: TerminationCondition | None = None, autonomous_mode: bool = False, @@ -208,7 +208,7 @@ def __init__( Args: agent: The agent to execute handoffs: Sequence of handoff configurations defining target agents - agent_thread: Optional AgentThread that manages the agent's execution context + agent_thread: Optional AgentSession that manages the agent's execution context is_start_agent: Whether this agent is the starting agent in the handoff workflow. There can only be one starting agent in a handoff workflow. 
termination_condition: Optional callable that determines when to terminate the workflow diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py b/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py index eec597cdda..17b927326b 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py @@ -1338,8 +1338,8 @@ async def handle_magentic_reset(self, signal: MagenticResetSignal, ctx: Workflow # Request into related self._pending_agent_requests.clear() self._pending_responses_to_agent.clear() - # Reset threads - self._agent_thread = self._agent.get_new_thread() + # Reset sessions + self._agent_thread = self._agent.create_session() # endregion Magentic Agent Executor diff --git a/python/packages/purview/agent_framework_purview/_middleware.py b/python/packages/purview/agent_framework_purview/_middleware.py index 70f3b2892f..f0165704d8 100644 --- a/python/packages/purview/agent_framework_purview/_middleware.py +++ b/python/packages/purview/agent_framework_purview/_middleware.py @@ -48,12 +48,12 @@ def _get_agent_session_id(context: AgentContext) -> str | None: """Resolve a session/conversation id from the agent run context. Resolution order: - 1. thread.service_thread_id + 1. session.service_session_id 2. First message whose additional_properties contains 'conversation_id' 3. 
None: the downstream processor will generate a new UUID """ - if context.thread and context.thread.service_thread_id: - return context.thread.service_thread_id + if context.session and context.session.service_session_id: + return context.session.service_session_id for message in context.messages: conversation_id = message.additional_properties.get("conversation_id") diff --git a/python/packages/redis/agent_framework_redis/__init__.py b/python/packages/redis/agent_framework_redis/__init__.py index 9453401441..fefbe99e6d 100644 --- a/python/packages/redis/agent_framework_redis/__init__.py +++ b/python/packages/redis/agent_framework_redis/__init__.py @@ -1,10 +1,8 @@ # Copyright (c) Microsoft. All rights reserved. import importlib.metadata -from ._chat_message_store import RedisChatMessageStore -from ._context_provider import _RedisContextProvider -from ._history_provider import _RedisHistoryProvider -from ._provider import RedisProvider +from ._context_provider import RedisContextProvider +from ._history_provider import RedisHistoryProvider try: __version__ = importlib.metadata.version(__name__) @@ -12,9 +10,7 @@ __version__ = "0.0.0" # Fallback for development mode __all__ = [ - "RedisChatMessageStore", - "RedisProvider", - "_RedisContextProvider", - "_RedisHistoryProvider", + "RedisContextProvider", + "RedisHistoryProvider", "__version__", ] diff --git a/python/packages/redis/agent_framework_redis/_context_provider.py b/python/packages/redis/agent_framework_redis/_context_provider.py index f4e44a8677..e8e0dbd66e 100644 --- a/python/packages/redis/agent_framework_redis/_context_provider.py +++ b/python/packages/redis/agent_framework_redis/_context_provider.py @@ -2,9 +2,8 @@ """New-pattern Redis context provider using BaseContextProvider. -This module provides ``_RedisContextProvider``, a side-by-side implementation of -:class:`RedisProvider` built on the new :class:`BaseContextProvider` hooks pattern. 
-It will be renamed to ``RedisContextProvider`` in PR2 when the old class is removed. +This module provides ``RedisContextProvider``, built on the new +:class:`BaseContextProvider` hooks pattern. """ from __future__ import annotations @@ -43,17 +42,11 @@ from agent_framework._agents import SupportsAgentRun -class _RedisContextProvider(BaseContextProvider): +class RedisContextProvider(BaseContextProvider): """Redis context provider using the new BaseContextProvider hooks pattern. Stores context in Redis and retrieves scoped context via full-text or - optional hybrid vector search. This is the new-pattern equivalent of - :class:`RedisProvider`. - - Note: - This class uses a temporary ``_`` prefix to coexist with the existing - :class:`RedisProvider`. It will be renamed to ``RedisContextProvider`` - in PR2. + optional hybrid vector search. """ DEFAULT_CONTEXT_PROMPT = "## Memories\nConsider the following memories when answering user questions:" @@ -429,4 +422,4 @@ async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseExc """Async context manager exit.""" -__all__ = ["_RedisContextProvider"] +__all__ = ["RedisContextProvider"] diff --git a/python/packages/redis/agent_framework_redis/_history_provider.py b/python/packages/redis/agent_framework_redis/_history_provider.py index 54d1ec5f81..9cb39be202 100644 --- a/python/packages/redis/agent_framework_redis/_history_provider.py +++ b/python/packages/redis/agent_framework_redis/_history_provider.py @@ -2,9 +2,8 @@ """New-pattern Redis history provider using BaseHistoryProvider. -This module provides ``_RedisHistoryProvider``, a side-by-side implementation of -:class:`RedisMessageStore` built on the new :class:`BaseHistoryProvider` hooks pattern. -It will be renamed to ``RedisHistoryProvider`` in PR2 when the old class is removed. +This module provides ``RedisHistoryProvider``, built on the new +:class:`BaseHistoryProvider` hooks pattern. 
""" from __future__ import annotations @@ -18,17 +17,11 @@ from redis.credentials import CredentialProvider -class _RedisHistoryProvider(BaseHistoryProvider): +class RedisHistoryProvider(BaseHistoryProvider): """Redis-backed history provider using the new BaseHistoryProvider hooks pattern. Stores conversation history in Redis Lists, with each session isolated by a - unique Redis key. This is the new-pattern equivalent of - :class:`RedisMessageStore`. - - Note: - This class uses a temporary ``_`` prefix to coexist with the existing - :class:`RedisMessageStore`. It will be renamed to ``RedisHistoryProvider`` - in PR2. + unique Redis key. """ def __init__( @@ -181,4 +174,4 @@ async def aclose(self) -> None: await self._redis_client.aclose() # type: ignore[misc] -__all__ = ["_RedisHistoryProvider"] +__all__ = ["RedisHistoryProvider"] diff --git a/python/packages/redis/tests/test_new_providers.py b/python/packages/redis/tests/test_new_providers.py index 3540386873..9bd388517d 100644 --- a/python/packages/redis/tests/test_new_providers.py +++ b/python/packages/redis/tests/test_new_providers.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. 
-"""Tests for _RedisContextProvider and _RedisHistoryProvider.""" +"""Tests for RedisContextProvider and RedisHistoryProvider.""" from __future__ import annotations @@ -12,8 +12,8 @@ from agent_framework._sessions import AgentSession, SessionContext from agent_framework.exceptions import ServiceInitializationError -from agent_framework_redis._context_provider import _RedisContextProvider -from agent_framework_redis._history_provider import _RedisHistoryProvider +from agent_framework_redis._context_provider import RedisContextProvider +from agent_framework_redis._history_provider import RedisHistoryProvider # --------------------------------------------------------------------------- # Shared fixtures @@ -63,13 +63,13 @@ def mock_redis_client(): # =========================================================================== -# _RedisContextProvider tests +# RedisContextProvider tests # =========================================================================== class TestRedisContextProviderInit: def test_basic_construction(self, patch_index_from_dict: MagicMock): # noqa: ARG002 - provider = _RedisContextProvider(source_id="ctx", user_id="u1") + provider = RedisContextProvider(source_id="ctx", user_id="u1") assert provider.source_id == "ctx" assert provider.user_id == "u1" assert provider.redis_url == "redis://localhost:6379" @@ -77,7 +77,7 @@ def test_basic_construction(self, patch_index_from_dict: MagicMock): # noqa: AR assert provider.prefix == "context" def test_custom_params(self, patch_index_from_dict: MagicMock): # noqa: ARG002 - provider = _RedisContextProvider( + provider = RedisContextProvider( source_id="ctx", redis_url="redis://custom:6380", index_name="my_idx", @@ -95,31 +95,31 @@ def test_custom_params(self, patch_index_from_dict: MagicMock): # noqa: ARG002 assert provider.context_prompt == "Custom prompt" def test_default_context_prompt(self, patch_index_from_dict: MagicMock): # noqa: ARG002 - provider = _RedisContextProvider(source_id="ctx", 
user_id="u1") + provider = RedisContextProvider(source_id="ctx", user_id="u1") assert "Memories" in provider.context_prompt def test_invalid_vectorizer_raises(self, patch_index_from_dict: MagicMock): # noqa: ARG002 from agent_framework.exceptions import AgentException with pytest.raises(AgentException, match="not a valid type"): - _RedisContextProvider(source_id="ctx", user_id="u1", redis_vectorizer="bad") # type: ignore[arg-type] + RedisContextProvider(source_id="ctx", user_id="u1", redis_vectorizer="bad") # type: ignore[arg-type] class TestRedisContextProviderValidateFilters: def test_no_filters_raises(self, patch_index_from_dict: MagicMock): # noqa: ARG002 - provider = _RedisContextProvider(source_id="ctx") + provider = RedisContextProvider(source_id="ctx") with pytest.raises(ServiceInitializationError, match="(?i)at least one"): provider._validate_filters() def test_any_single_filter_ok(self, patch_index_from_dict: MagicMock): # noqa: ARG002 for kwargs in [{"user_id": "u"}, {"agent_id": "a"}, {"application_id": "app"}]: - provider = _RedisContextProvider(source_id="ctx", **kwargs) + provider = RedisContextProvider(source_id="ctx", **kwargs) provider._validate_filters() # should not raise class TestRedisContextProviderSchema: def test_schema_has_expected_fields(self, patch_index_from_dict: MagicMock): # noqa: ARG002 - provider = _RedisContextProvider(source_id="ctx", user_id="u1") + provider = RedisContextProvider(source_id="ctx", user_id="u1") schema = provider.schema_dict field_names = [f["name"] for f in schema["fields"]] for expected in ("role", "content", "conversation_id", "message_id", "application_id", "agent_id", "user_id"): @@ -128,7 +128,7 @@ def test_schema_has_expected_fields(self, patch_index_from_dict: MagicMock): # assert schema["index"]["prefix"] == "context" def test_schema_no_vector_without_vectorizer(self, patch_index_from_dict: MagicMock): # noqa: ARG002 - provider = _RedisContextProvider(source_id="ctx", user_id="u1") + provider = 
RedisContextProvider(source_id="ctx", user_id="u1") field_types = [f["type"] for f in provider.schema_dict["fields"]] assert "vector" not in field_types @@ -140,7 +140,7 @@ async def test_search_results_added_to_context( patch_index_from_dict: MagicMock, # noqa: ARG002 ): mock_index.query = AsyncMock(return_value=[{"content": "Memory A"}, {"content": "Memory B"}]) - provider = _RedisContextProvider(source_id="ctx", user_id="u1") + provider = RedisContextProvider(source_id="ctx", user_id="u1") session = AgentSession(session_id="test-session") ctx = SessionContext(input_messages=[Message(role="user", contents=["test query"])], session_id="s1") @@ -157,7 +157,7 @@ async def test_empty_input_no_search( mock_index: AsyncMock, patch_index_from_dict: MagicMock, # noqa: ARG002 ): - provider = _RedisContextProvider(source_id="ctx", user_id="u1") + provider = RedisContextProvider(source_id="ctx", user_id="u1") session = AgentSession(session_id="test-session") ctx = SessionContext(input_messages=[Message(role="user", contents=[" "])], session_id="s1") @@ -172,7 +172,7 @@ async def test_empty_results_no_messages( patch_index_from_dict: MagicMock, # noqa: ARG002 ): mock_index.query = AsyncMock(return_value=[]) - provider = _RedisContextProvider(source_id="ctx", user_id="u1") + provider = RedisContextProvider(source_id="ctx", user_id="u1") session = AgentSession(session_id="test-session") ctx = SessionContext(input_messages=[Message(role="user", contents=["hello"])], session_id="s1") @@ -187,7 +187,7 @@ async def test_stores_messages( mock_index: AsyncMock, patch_index_from_dict: MagicMock, # noqa: ARG002 ): - provider = _RedisContextProvider(source_id="ctx", user_id="u1") + provider = RedisContextProvider(source_id="ctx", user_id="u1") session = AgentSession(session_id="test-session") response = AgentResponse(messages=[Message(role="assistant", contents=["response text"])]) ctx = SessionContext(input_messages=[Message(role="user", contents=["user input"])], session_id="s1") @@ 
-206,7 +206,7 @@ async def test_skips_empty_conversations( mock_index: AsyncMock, patch_index_from_dict: MagicMock, # noqa: ARG002 ): - provider = _RedisContextProvider(source_id="ctx", user_id="u1") + provider = RedisContextProvider(source_id="ctx", user_id="u1") session = AgentSession(session_id="test-session") ctx = SessionContext(input_messages=[Message(role="user", contents=[" "])], session_id="s1") @@ -219,7 +219,7 @@ async def test_stores_partition_fields( mock_index: AsyncMock, patch_index_from_dict: MagicMock, # noqa: ARG002 ): - provider = _RedisContextProvider(source_id="ctx", application_id="app", agent_id="ag", user_id="u1") + provider = RedisContextProvider(source_id="ctx", application_id="app", agent_id="ag", user_id="u1") session = AgentSession(session_id="test-session") ctx = SessionContext(input_messages=[Message(role="user", contents=["hello"])], session_id="s1") @@ -235,13 +235,13 @@ async def test_stores_partition_fields( class TestRedisContextProviderContextManager: async def test_aenter_returns_self(self, patch_index_from_dict: MagicMock): # noqa: ARG002 - provider = _RedisContextProvider(source_id="ctx", user_id="u1") + provider = RedisContextProvider(source_id="ctx", user_id="u1") async with provider as p: assert p is provider # =========================================================================== -# _RedisHistoryProvider tests +# RedisHistoryProvider tests # =========================================================================== @@ -249,7 +249,7 @@ class TestRedisHistoryProviderInit: def test_basic_construction(self, mock_redis_client: MagicMock): with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: mock_from_url.return_value = mock_redis_client - provider = _RedisHistoryProvider("memory", redis_url="redis://localhost:6379") + provider = RedisHistoryProvider("memory", redis_url="redis://localhost:6379") assert provider.source_id == "memory" assert provider.key_prefix == "chat_messages" @@ -261,7 
+261,7 @@ def test_basic_construction(self, mock_redis_client: MagicMock): def test_custom_params(self, mock_redis_client: MagicMock): with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: mock_from_url.return_value = mock_redis_client - provider = _RedisHistoryProvider( + provider = RedisHistoryProvider( "mem", redis_url="redis://localhost:6379", key_prefix="custom", @@ -279,12 +279,12 @@ def test_custom_params(self, mock_redis_client: MagicMock): def test_no_redis_url_or_credential_raises(self): with pytest.raises(ValueError, match="Either redis_url or credential_provider must be provided"): - _RedisHistoryProvider("mem") + RedisHistoryProvider("mem") def test_both_url_and_credential_raises(self): mock_cred = MagicMock() with pytest.raises(ValueError, match="mutually exclusive"): - _RedisHistoryProvider( + RedisHistoryProvider( "mem", redis_url="redis://localhost:6379", credential_provider=mock_cred, @@ -294,13 +294,13 @@ def test_both_url_and_credential_raises(self): def test_credential_provider_without_host_raises(self): mock_cred = MagicMock() with pytest.raises(ValueError, match="host is required"): - _RedisHistoryProvider("mem", credential_provider=mock_cred) + RedisHistoryProvider("mem", credential_provider=mock_cred) def test_credential_provider_with_host(self): mock_cred = MagicMock() with patch("agent_framework_redis._history_provider.redis.Redis") as mock_redis_cls: mock_redis_cls.return_value = MagicMock() - provider = _RedisHistoryProvider("mem", credential_provider=mock_cred, host="myhost") + provider = RedisHistoryProvider("mem", credential_provider=mock_cred, host="myhost") mock_redis_cls.assert_called_once_with( host="myhost", @@ -317,7 +317,7 @@ class TestRedisHistoryProviderRedisKey: def test_key_format(self, mock_redis_client: MagicMock): with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: mock_from_url.return_value = mock_redis_client - provider = _RedisHistoryProvider("mem", 
redis_url="redis://localhost:6379", key_prefix="msgs") + provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379", key_prefix="msgs") assert provider._redis_key("session-123") == "msgs:session-123" assert provider._redis_key(None) == "msgs:default" @@ -331,7 +331,7 @@ async def test_returns_deserialized_messages(self, mock_redis_client: MagicMock) with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: mock_from_url.return_value = mock_redis_client - provider = _RedisHistoryProvider("mem", redis_url="redis://localhost:6379") + provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379") messages = await provider.get_messages("s1") assert len(messages) == 2 @@ -345,7 +345,7 @@ async def test_empty_returns_empty(self, mock_redis_client: MagicMock): with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: mock_from_url.return_value = mock_redis_client - provider = _RedisHistoryProvider("mem", redis_url="redis://localhost:6379") + provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379") messages = await provider.get_messages("s1") assert messages == [] @@ -355,7 +355,7 @@ class TestRedisHistoryProviderSaveMessages: async def test_saves_serialized_messages(self, mock_redis_client: MagicMock): with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: mock_from_url.return_value = mock_redis_client - provider = _RedisHistoryProvider("mem", redis_url="redis://localhost:6379") + provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379") msgs = [Message(role="user", contents=["Hello"]), Message(role="assistant", contents=["Hi"])] await provider.save_messages("s1", msgs) @@ -367,7 +367,7 @@ async def test_saves_serialized_messages(self, mock_redis_client: MagicMock): async def test_empty_messages_noop(self, mock_redis_client: MagicMock): with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: 
mock_from_url.return_value = mock_redis_client - provider = _RedisHistoryProvider("mem", redis_url="redis://localhost:6379") + provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379") await provider.save_messages("s1", []) mock_redis_client.pipeline.assert_not_called() @@ -377,7 +377,7 @@ async def test_max_messages_trimming(self, mock_redis_client: MagicMock): with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: mock_from_url.return_value = mock_redis_client - provider = _RedisHistoryProvider("mem", redis_url="redis://localhost:6379", max_messages=10) + provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379", max_messages=10) await provider.save_messages("s1", [Message(role="user", contents=["msg"])]) @@ -388,7 +388,7 @@ async def test_no_trim_when_under_limit(self, mock_redis_client: MagicMock): with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: mock_from_url.return_value = mock_redis_client - provider = _RedisHistoryProvider("mem", redis_url="redis://localhost:6379", max_messages=10) + provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379", max_messages=10) await provider.save_messages("s1", [Message(role="user", contents=["msg"])]) @@ -399,7 +399,7 @@ class TestRedisHistoryProviderClear: async def test_clear_calls_delete(self, mock_redis_client: MagicMock): with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: mock_from_url.return_value = mock_redis_client - provider = _RedisHistoryProvider("mem", redis_url="redis://localhost:6379") + provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379") await provider.clear("session-1") mock_redis_client.delete.assert_called_once_with("chat_messages:session-1") @@ -414,7 +414,7 @@ async def test_before_run_loads_history(self, mock_redis_client: MagicMock): with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: 
mock_from_url.return_value = mock_redis_client - provider = _RedisHistoryProvider("mem", redis_url="redis://localhost:6379") + provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379") session = AgentSession(session_id="test") ctx = SessionContext(input_messages=[Message(role="user", contents=["new msg"])], session_id="s1") @@ -428,7 +428,7 @@ async def test_before_run_loads_history(self, mock_redis_client: MagicMock): async def test_after_run_stores_input_and_response(self, mock_redis_client: MagicMock): with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: mock_from_url.return_value = mock_redis_client - provider = _RedisHistoryProvider("mem", redis_url="redis://localhost:6379") + provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379") session = AgentSession(session_id="test") ctx = SessionContext(input_messages=[Message(role="user", contents=["hi"])], session_id="s1") @@ -443,7 +443,7 @@ async def test_after_run_stores_input_and_response(self, mock_redis_client: Magi async def test_after_run_skips_when_no_messages(self, mock_redis_client: MagicMock): with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url: mock_from_url.return_value = mock_redis_client - provider = _RedisHistoryProvider( + provider = RedisHistoryProvider( "mem", redis_url="redis://localhost:6379", store_inputs=False, store_outputs=False ) From b10e71a8c93f105c754ea052738fcf9b8fb5ea99 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Feb 2026 12:10:36 +0100 Subject: [PATCH 02/28] fix: update all tests for context provider pipeline, fix lazy-loaders, remove old test files --- python/packages/ag-ui/tests/ag_ui/conftest.py | 20 +- .../ag_ui/test_agent_wrapper_comprehensive.py | 38 +- .../tests/test_search_provider.py | 1018 ----------------- .../tests/test_azure_ai_agent_client.py | 26 +- .../claude/tests/test_claude_agent.py | 47 +- .../packages/core/agent_framework/_agents.py | 2 +- 
.../core/agent_framework/mem0/__init__.py | 2 +- .../core/agent_framework/redis/__init__.py | 2 +- .../azure/test_azure_assistants_client.py | 30 +- .../tests/azure/test_azure_chat_client.py | 30 +- .../azure/test_azure_responses_client.py | 22 +- python/packages/core/tests/core/conftest.py | 26 +- .../packages/core/tests/core/test_agents.py | 336 +++--- .../core/test_function_invocation_logic.py | 8 +- .../core/tests/core/test_middleware.py | 52 +- .../tests/core/test_middleware_with_agent.py | 35 +- .../core/tests/core/test_observability.py | 36 +- .../openai/test_openai_assistants_client.py | 30 +- .../tests/workflow/test_agent_executor.py | 113 +- .../test_agent_executor_tool_calls.py | 4 +- .../core/tests/workflow/test_agent_utils.py | 8 +- .../tests/workflow/test_full_conversation.py | 6 +- .../core/tests/workflow/test_workflow.py | 4 +- .../tests/workflow/test_workflow_agent.py | 102 +- .../tests/workflow/test_workflow_builder.py | 4 +- .../tests/workflow/test_workflow_kwargs.py | 4 +- python/packages/devui/tests/devui/conftest.py | 22 +- .../devui/tests/devui/test_discovery.py | 18 +- .../devui/tests/devui/test_execution.py | 10 +- .../tests/test_github_copilot_agent.py | 80 +- .../mem0/tests/test_mem0_context_provider.py | 521 ++++----- .../_handoff.py | 5 +- .../orchestrations/tests/test_group_chat.py | 8 +- .../orchestrations/tests/test_handoff.py | 20 +- .../orchestrations/tests/test_magentic.py | 10 +- .../tests/test_orchestration_request_info.py | 10 +- .../orchestrations/tests/test_sequential.py | 4 +- .../purview/tests/purview/test_middleware.py | 26 +- .../tests/test_redis_chat_message_store.py | 621 ---------- .../redis/tests/test_redis_provider.py | 425 ------- 40 files changed, 785 insertions(+), 3000 deletions(-) delete mode 100644 python/packages/azure-ai-search/tests/test_search_provider.py delete mode 100644 python/packages/redis/tests/test_redis_chat_message_store.py delete mode 100644 python/packages/redis/tests/test_redis_provider.py diff 
--git a/python/packages/ag-ui/tests/ag_ui/conftest.py b/python/packages/ag-ui/tests/ag_ui/conftest.py index 82f6267863..09a4ff57f1 100644 --- a/python/packages/ag-ui/tests/ag_ui/conftest.py +++ b/python/packages/ag-ui/tests/ag_ui/conftest.py @@ -11,7 +11,7 @@ from agent_framework import ( AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, BaseChatClient, ChatOptions, ChatResponse, @@ -49,8 +49,8 @@ def __init__(self, stream_fn: StreamFn, response_fn: ResponseFn | None = None) - super().__init__(function_middleware=[]) self._stream_fn = stream_fn self._response_fn = response_fn - self.last_thread: AgentThread | None = None - self.last_service_thread_id: str | None = None + self.last_session: AgentSession | None = None + self.last_service_session_id: str | None = None @overload def get_response( @@ -90,8 +90,8 @@ def get_response( options: OptionsCoT | ChatOptions[Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: - self.last_thread = kwargs.get("thread") - self.last_service_thread_id = self.last_thread.service_thread_id if self.last_thread else None + self.last_session = kwargs.get("session") + self.last_service_session_id = self.last_session.service_session_id if self.last_session else None return cast( Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]], super().get_response( @@ -178,7 +178,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: Literal[False] = ..., - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]]: ... @@ -188,7 +188,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: Literal[True], - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... 
@@ -197,7 +197,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: if stream: @@ -218,8 +218,8 @@ async def _get_response() -> AgentResponse[Any]: return _get_response() - def get_new_thread(self, **kwargs: Any) -> AgentThread: - return AgentThread() + def create_session(self, **kwargs: Any) -> AgentSession: + return AgentSession() # Fixtures diff --git a/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py index f597c081f4..165756af39 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py @@ -444,13 +444,7 @@ async def stream_fn( async for event in wrapper.run_agent(input_data): events.append(event) - # AG-UI internal metadata should be stored in thread.metadata - thread = agent.client.last_thread - thread_metadata = thread.metadata if thread and hasattr(thread, "metadata") else {} - assert thread_metadata.get("ag_ui_thread_id") == "test_thread_123" - assert thread_metadata.get("ag_ui_run_id") == "test_run_456" - - # Internal metadata should NOT be passed to chat client options + # AG-UI internal metadata should NOT be passed to chat client options options_metadata = captured_options.get("metadata", {}) assert "ag_ui_thread_id" not in options_metadata assert "ag_ui_run_id" not in options_metadata @@ -488,15 +482,7 @@ async def stream_fn( async for event in wrapper.run_agent(input_data): events.append(event) - # Current state should be stored in thread.metadata - thread = agent.client.last_thread - thread_metadata = thread.metadata if thread and hasattr(thread, "metadata") else {} - current_state = thread_metadata.get("current_state") - if 
isinstance(current_state, str): - current_state = json.loads(current_state) - assert current_state == {"document": "Test content"} - - # Internal metadata should NOT be passed to chat client options + # Current state should NOT be passed to chat client options options_metadata = captured_options.get("metadata", {}) assert "current_state" not in options_metadata @@ -612,10 +598,10 @@ async def stream_fn( async def test_agent_with_use_service_thread_is_false(streaming_chat_client_stub): - """Test that when use_service_thread is False, the AgentThread used to run the agent is NOT set to the service thread ID.""" + """Test that when use_service_thread is False, the AgentSession used to run the agent is NOT set to the service session ID.""" from agent_framework.ag_ui import AgentFrameworkAgent - request_service_thread_id: str | None = None + request_service_session_id: str | None = None async def stream_fn( messages: MutableSequence[Message], chat_options: ChatOptions, **kwargs: Any @@ -632,21 +618,21 @@ async def stream_fn( events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) - assert request_service_thread_id is None # type: ignore[attr-defined] (service_thread_id should be set) + assert request_service_session_id is None # type: ignore[attr-defined] (service_session_id should be set) async def test_agent_with_use_service_thread_is_true(streaming_chat_client_stub): - """Test that when use_service_thread is True, the AgentThread used to run the agent is set to the service thread ID.""" + """Test that when use_service_thread is True, the AgentSession used to run the agent is set to the service session ID.""" from agent_framework.ag_ui import AgentFrameworkAgent - request_service_thread_id: str | None = None + request_service_session_id: str | None = None async def stream_fn( messages: MutableSequence[Message], chat_options: ChatOptions, **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - nonlocal request_service_thread_id - 
thread = kwargs.get("thread") - request_service_thread_id = thread.service_thread_id if thread else None + nonlocal request_service_session_id + session = kwargs.get("session") + request_service_session_id = session.service_session_id if session else None yield ChatResponseUpdate( contents=[Content.from_text(text="Response")], response_id="resp_67890", conversation_id="conv_12345" ) @@ -659,8 +645,8 @@ async def stream_fn( events: list[Any] = [] async for event in wrapper.run_agent(input_data): events.append(event) - request_service_thread_id = agent.client.last_service_thread_id - assert request_service_thread_id == "conv_123456" # type: ignore[attr-defined] (service_thread_id should be set) + request_service_session_id = agent.client.last_service_session_id + assert request_service_session_id == "conv_123456" # type: ignore[attr-defined] (service_session_id should be set) async def test_function_approval_mode_executes_tool(streaming_chat_client_stub): diff --git a/python/packages/azure-ai-search/tests/test_search_provider.py b/python/packages/azure-ai-search/tests/test_search_provider.py deleted file mode 100644 index bcdbb9b5ef..0000000000 --- a/python/packages/azure-ai-search/tests/test_search_provider.py +++ /dev/null @@ -1,1018 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. 
-# pyright: reportPrivateUsage=false - -import os -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from agent_framework import Context, Message -from agent_framework._settings import load_settings -from agent_framework.azure import AzureAISearchContextProvider, AzureAISearchSettings -from agent_framework.exceptions import ServiceInitializationError -from azure.core.credentials import AzureKeyCredential -from azure.core.exceptions import ResourceNotFoundError - - -@pytest.fixture -def mock_search_client() -> AsyncMock: - """Create a mock SearchClient.""" - mock_client = AsyncMock() - mock_client.search = AsyncMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock() - return mock_client - - -@pytest.fixture -def mock_index_client() -> AsyncMock: - """Create a mock SearchIndexClient.""" - mock_client = AsyncMock() - mock_client.get_knowledge_source = AsyncMock() - mock_client.create_knowledge_source = AsyncMock() - mock_client.get_agent = AsyncMock() - mock_client.create_agent = AsyncMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock() - return mock_client - - -@pytest.fixture -def sample_messages() -> list[Message]: - """Create sample chat messages for testing.""" - return [ - Message(role="user", text="What is in the documents?"), - ] - - -class TestAzureAISearchSettings: - """Test AzureAISearchSettings configuration.""" - - def test_settings_with_direct_values(self) -> None: - """Test settings with direct values.""" - settings = load_settings( - AzureAISearchSettings, - env_prefix="AZURE_SEARCH_", - endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="test-key", - ) - assert settings["endpoint"] == "https://test.search.windows.net" - assert settings["index_name"] == "test-index" - assert settings["api_key"] == "test-key" - - def test_settings_with_env_file_path(self) -> None: - """Test settings with 
env_file_path parameter.""" - settings = load_settings( - AzureAISearchSettings, - env_prefix="AZURE_SEARCH_", - endpoint="https://test.search.windows.net", - index_name="test-index", - ) - assert settings["endpoint"] == "https://test.search.windows.net" - assert settings["index_name"] == "test-index" - - def test_provider_uses_settings_from_env(self) -> None: - """Test that provider creates settings internally from env.""" - provider = AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="test-key", - ) - assert provider.endpoint == "https://test.search.windows.net" - assert provider.index_name == "test-index" - - def test_provider_missing_endpoint_raises_error(self) -> None: - """Test that provider raises ServiceInitializationError without endpoint.""" - # Use patch.dict to clear environment and pass env_file_path="" to prevent .env file loading - clean_env = {k: v for k, v in os.environ.items() if not k.startswith("AZURE_SEARCH_")} - with ( - patch.dict(os.environ, clean_env, clear=True), - pytest.raises(ServiceInitializationError, match="endpoint is required"), - ): - AzureAISearchContextProvider( - index_name="test-index", - api_key="test-key", - env_file_path="", # Disable .env file loading - ) - - def test_provider_missing_index_name_raises_error(self) -> None: - """Test that provider raises ServiceInitializationError without index_name.""" - # Use patch.dict to clear environment and pass env_file_path="" to prevent .env file loading - clean_env = {k: v for k, v in os.environ.items() if not k.startswith("AZURE_SEARCH_")} - with ( - patch.dict(os.environ, clean_env, clear=True), - pytest.raises(ServiceInitializationError, match="index name is required"), - ): - AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - api_key="test-key", - env_file_path="", # Disable .env file loading - ) - - def test_provider_missing_credential_raises_error(self) -> None: - """Test that provider 
raises ServiceInitializationError without credential.""" - # Use patch.dict to clear environment and pass env_file_path="" to prevent .env file loading - clean_env = {k: v for k, v in os.environ.items() if not k.startswith("AZURE_SEARCH_")} - with ( - patch.dict(os.environ, clean_env, clear=True), - pytest.raises(ServiceInitializationError, match="credential is required"), - ): - AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - index_name="test-index", - env_file_path="", # Disable .env file loading - ) - - -class TestSearchProviderInitialization: - """Test initialization and configuration of AzureAISearchContextProvider.""" - - def test_init_semantic_mode_minimal(self) -> None: - """Test initialization with minimal semantic mode parameters.""" - provider = AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="test-key", - mode="semantic", - ) - assert provider.endpoint == "https://test.search.windows.net" - assert provider.index_name == "test-index" - assert provider.mode == "semantic" - assert provider.top_k == 5 - - def test_init_semantic_mode_with_vector_field_requires_embedding_function(self) -> None: - """Test that vector_field_name requires embedding_function.""" - with pytest.raises(ValueError, match="embedding_function is required"): - AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="test-key", - mode="semantic", - vector_field_name="embedding", - ) - - def test_init_agentic_mode_with_kb_only(self) -> None: - """Test agentic mode with existing knowledge_base_name (simplest path).""" - # Clear environment to ensure no env vars interfere - clean_env = {k: v for k, v in os.environ.items() if not k.startswith("AZURE_SEARCH_")} - with patch.dict(os.environ, clean_env, clear=True): - provider = AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - api_key="test-key", - mode="agentic", - 
knowledge_base_name="test-kb", - env_file_path="", # Disable .env file loading - ) - assert provider.mode == "agentic" - assert provider.knowledge_base_name == "test-kb" - assert provider._use_existing_knowledge_base is True - - def test_init_agentic_mode_with_index_requires_model(self) -> None: - """Test that agentic mode with index_name requires model_deployment_name.""" - # Clear environment to ensure no env vars interfere - clean_env = {k: v for k, v in os.environ.items() if not k.startswith("AZURE_SEARCH_")} - with ( - patch.dict(os.environ, clean_env, clear=True), - pytest.raises(ServiceInitializationError, match="model_deployment_name"), - ): - AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="test-key", - mode="agentic", - env_file_path="", # Disable .env file loading - ) - - def test_init_agentic_mode_with_index_and_model(self) -> None: - """Test agentic mode with index_name (auto-create KB path).""" - # Clear environment to ensure no env vars interfere - clean_env = {k: v for k, v in os.environ.items() if not k.startswith("AZURE_SEARCH_")} - with patch.dict(os.environ, clean_env, clear=True): - provider = AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="test-key", - mode="agentic", - model_deployment_name="gpt-4o", - azure_openai_resource_url="https://test.openai.azure.com", - env_file_path="", # Disable .env file loading - ) - assert provider.mode == "agentic" - assert provider.index_name == "test-index" - assert provider.knowledge_base_name == "test-index-kb" # Auto-generated - assert provider._use_existing_knowledge_base is False - - def test_init_agentic_mode_rejects_both_index_and_kb(self) -> None: - """Test that agentic mode rejects both index_name AND knowledge_base_name.""" - # Clear environment to ensure no env vars interfere - clean_env = {k: v for k, v in os.environ.items() if not k.startswith("AZURE_SEARCH_")} - with 
( - patch.dict(os.environ, clean_env, clear=True), - pytest.raises(ServiceInitializationError, match="either 'index_name' OR 'knowledge_base_name', not both"), - ): - AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="test-key", - mode="agentic", - knowledge_base_name="test-kb", - model_deployment_name="gpt-4o", - azure_openai_resource_url="https://test.openai.azure.com", - env_file_path="", # Disable .env file loading - ) - - def test_init_agentic_mode_requires_index_or_kb(self) -> None: - """Test that agentic mode requires either index_name or knowledge_base_name.""" - # Clear environment to ensure no env vars interfere - clean_env = {k: v for k, v in os.environ.items() if not k.startswith("AZURE_SEARCH_")} - with ( - patch.dict(os.environ, clean_env, clear=True), - pytest.raises(ServiceInitializationError, match="provide either 'index_name'.*or 'knowledge_base_name'"), - ): - AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - api_key="test-key", - mode="agentic", - env_file_path="", # Disable .env file loading - ) - - def test_init_model_name_defaults_to_deployment_name(self) -> None: - """Test that model_name defaults to deployment_name if not provided.""" - # Clear environment to ensure no env vars interfere - clean_env = {k: v for k, v in os.environ.items() if not k.startswith("AZURE_SEARCH_")} - with patch.dict(os.environ, clean_env, clear=True): - provider = AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - api_key="test-key", - mode="agentic", - knowledge_base_name="test-kb", - model_deployment_name="gpt-4o", - env_file_path="", # Disable .env file loading - ) - assert provider.model_name == "gpt-4o" - - def test_init_with_custom_context_prompt(self) -> None: - """Test initialization with custom context prompt.""" - custom_prompt = "Use the following information:" - provider = AzureAISearchContextProvider( - 
endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="test-key", - mode="semantic", - context_prompt=custom_prompt, - ) - assert provider.context_prompt == custom_prompt - - def test_init_uses_default_context_prompt(self) -> None: - """Test that default context prompt is used when not provided.""" - provider = AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="test-key", - mode="semantic", - ) - assert provider.context_prompt == provider._DEFAULT_SEARCH_CONTEXT_PROMPT - - -class TestSemanticSearch: - """Test semantic search functionality.""" - - @pytest.mark.asyncio - @patch("agent_framework_azure_ai_search._search_provider.SearchClient") - async def test_semantic_search_basic(self, mock_search_class: MagicMock, sample_messages: list[Message]) -> None: - """Test basic semantic search without vector search.""" - # Setup mock - mock_search_client = AsyncMock() - mock_results = AsyncMock() - mock_results.__aiter__.return_value = iter([{"content": "Test document content"}]) - mock_search_client.search.return_value = mock_results - mock_search_class.return_value = mock_search_client - - provider = AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="test-key", - mode="semantic", - ) - - context = await provider.invoking(sample_messages) - - assert isinstance(context, Context) - assert len(context.messages) > 1 # First message is prompt, rest are results - # First message should be the context prompt - assert "Use the following context" in context.messages[0].text - # Second message should contain the search result - assert "Test document content" in context.messages[1].text - - @pytest.mark.asyncio - @patch("agent_framework_azure_ai_search._search_provider.SearchClient") - async def test_semantic_search_empty_query(self, mock_search_class: MagicMock) -> None: - """Test that empty queries return empty context.""" - 
mock_search_client = AsyncMock() - mock_search_class.return_value = mock_search_client - - provider = AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="test-key", - mode="semantic", - ) - - # Empty message - context = await provider.invoking([Message(role="user", text="")]) - - assert isinstance(context, Context) - assert len(context.messages) == 0 - - @pytest.mark.asyncio - @patch("agent_framework_azure_ai_search._search_provider.SearchClient") - async def test_semantic_search_with_vector_query( - self, mock_search_class: MagicMock, sample_messages: list[Message] - ) -> None: - """Test semantic search with vector query.""" - # Setup mock - mock_search_client = AsyncMock() - mock_results = AsyncMock() - mock_results.__aiter__.return_value = iter([{"content": "Vector search result"}]) - mock_search_client.search.return_value = mock_results - mock_search_class.return_value = mock_search_client - - # Mock embedding function - async def mock_embed(text: str) -> list[float]: - return [0.1, 0.2, 0.3] - - provider = AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="test-key", - mode="semantic", - vector_field_name="embedding", - embedding_function=mock_embed, - ) - - context = await provider.invoking(sample_messages) - - assert isinstance(context, Context) - assert len(context.messages) > 0 - # Verify that search was called - mock_search_client.search.assert_called_once() - - -class TestKnowledgeBaseSetup: - """Test Knowledge Base setup for agentic mode.""" - - @pytest.mark.asyncio - @patch("agent_framework_azure_ai_search._search_provider.SearchIndexClient") - @patch("agent_framework_azure_ai_search._search_provider.SearchClient") - async def test_ensure_knowledge_base_creates_when_not_exists( - self, mock_search_class: MagicMock, mock_index_class: MagicMock - ) -> None: - """Test that Knowledge Base is created when it doesn't exist (index_name 
path).""" - # Setup mocks - mock_index_client = AsyncMock() - mock_index_client.get_knowledge_source.side_effect = ResourceNotFoundError("Not found") - mock_index_client.create_knowledge_source = AsyncMock() - mock_index_client.get_knowledge_base.side_effect = ResourceNotFoundError("Not found") - mock_index_client.create_or_update_knowledge_base = AsyncMock() - mock_index_class.return_value = mock_index_client - - mock_search_client = AsyncMock() - mock_search_class.return_value = mock_search_client - - # Clear environment to ensure no env vars interfere - clean_env = {k: v for k, v in os.environ.items() if not k.startswith("AZURE_SEARCH_")} - with patch.dict(os.environ, clean_env, clear=True): - # Use index_name path (auto-create KB) - provider = AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="test-key", - mode="agentic", - model_deployment_name="gpt-4o", - azure_openai_resource_url="https://test.openai.azure.com", - env_file_path="", # Disable .env file loading - ) - - await provider._ensure_knowledge_base() - - # Verify knowledge source was created - mock_index_client.create_knowledge_source.assert_called_once() - # Verify Knowledge Base was created - mock_index_client.create_or_update_knowledge_base.assert_called_once() - - @pytest.mark.asyncio - @patch("agent_framework_azure_ai_search._search_provider.SearchIndexClient") - @patch("agent_framework_azure_ai_search._search_provider.SearchClient") - async def test_ensure_knowledge_base_skips_when_using_existing_kb( - self, mock_search_class: MagicMock, mock_index_class: MagicMock - ) -> None: - """Test that KB setup is skipped when using existing knowledge_base_name.""" - # Setup mocks - mock_index_client = AsyncMock() - mock_index_class.return_value = mock_index_client - - mock_search_client = AsyncMock() - mock_search_class.return_value = mock_search_client - - # Clear environment to ensure no env vars interfere - clean_env = {k: v for k, v in 
os.environ.items() if not k.startswith("AZURE_SEARCH_")} - with patch.dict(os.environ, clean_env, clear=True): - # Use knowledge_base_name path (existing KB) - provider = AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - api_key="test-key", - mode="agentic", - knowledge_base_name="test-kb", - env_file_path="", # Disable .env file loading - ) - - await provider._ensure_knowledge_base() - - # Verify nothing was created (using existing KB) - mock_index_client.create_knowledge_source.assert_not_called() - mock_index_client.create_or_update_knowledge_base.assert_not_called() - - -class TestContextProviderLifecycle: - """Test context provider lifecycle methods.""" - - @pytest.mark.asyncio - @patch("agent_framework_azure_ai_search._search_provider.SearchClient") - async def test_context_manager(self, mock_search_class: MagicMock) -> None: - """Test that provider can be used as async context manager.""" - mock_search_client = AsyncMock() - mock_search_class.return_value = mock_search_client - - async with AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="test-key", - mode="semantic", - ) as provider: - assert provider is not None - assert isinstance(provider, AzureAISearchContextProvider) - - @pytest.mark.asyncio - @patch("agent_framework_azure_ai_search._search_provider.KnowledgeBaseRetrievalClient") - @patch("agent_framework_azure_ai_search._search_provider.SearchIndexClient") - @patch("agent_framework_azure_ai_search._search_provider.SearchClient") - async def test_context_manager_agentic_cleanup( - self, mock_search_class: MagicMock, mock_index_class: MagicMock, mock_retrieval_class: MagicMock - ) -> None: - """Test that agentic mode provider cleans up retrieval client.""" - mock_search_client = AsyncMock() - mock_search_class.return_value = mock_search_client - - mock_index_client = AsyncMock() - mock_index_class.return_value = mock_index_client - - mock_retrieval_client = 
AsyncMock() - mock_retrieval_client.close = AsyncMock() - mock_retrieval_class.return_value = mock_retrieval_client - - # Clear environment to ensure no env vars interfere - clean_env = {k: v for k, v in os.environ.items() if not k.startswith("AZURE_SEARCH_")} - with patch.dict(os.environ, clean_env, clear=True): - # Use knowledge_base_name path (existing KB) - async with AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - api_key="test-key", - mode="agentic", - knowledge_base_name="test-kb", - env_file_path="", # Disable .env file loading - ) as provider: - # Simulate retrieval client being created - provider._retrieval_client = mock_retrieval_client - - # Verify cleanup was called - mock_retrieval_client.close.assert_called_once() - - def test_string_api_key_conversion(self) -> None: - """Test that string api_key is converted to AzureKeyCredential.""" - provider = AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="my-api-key", # String api_key - mode="semantic", - ) - assert isinstance(provider.credential, AzureKeyCredential) - - -class TestMessageFiltering: - """Test message filtering functionality.""" - - @pytest.mark.asyncio - @patch("agent_framework_azure_ai_search._search_provider.SearchClient") - async def test_filters_non_user_assistant_messages(self, mock_search_class: MagicMock) -> None: - """Test that only USER and ASSISTANT messages are processed.""" - # Setup mock - mock_search_client = AsyncMock() - mock_results = AsyncMock() - mock_results.__aiter__.return_value = iter([{"content": "Test result"}]) - mock_search_client.search.return_value = mock_results - mock_search_class.return_value = mock_search_client - - provider = AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="test-key", - mode="semantic", - ) - - # Mix of message types - messages = [ - Message(role="system", text="System message"), - 
Message(role="user", text="User message"), - Message(role="assistant", text="Assistant message"), - Message(role="tool", text="Tool message"), - ] - - context = await provider.invoking(messages) - - # Should have processed only USER and ASSISTANT messages - assert isinstance(context, Context) - mock_search_client.search.assert_called_once() - - @pytest.mark.asyncio - @patch("agent_framework_azure_ai_search._search_provider.SearchClient") - async def test_filters_empty_messages(self, mock_search_class: MagicMock) -> None: - """Test that empty/whitespace messages are filtered out.""" - mock_search_client = AsyncMock() - mock_search_class.return_value = mock_search_client - - provider = AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="test-key", - mode="semantic", - ) - - # Messages with empty/whitespace text - messages = [ - Message(role="user", text=""), - Message(role="user", text=" "), - Message(role="user", text=""), # Message with None text becomes empty string - ] - - context = await provider.invoking(messages) - - # Should return empty context - assert len(context.messages) == 0 - - -class TestCitations: - """Test citation functionality.""" - - @pytest.mark.asyncio - @patch("agent_framework_azure_ai_search._search_provider.SearchClient") - async def test_citations_included_in_semantic_search(self, mock_search_class: MagicMock) -> None: - """Test that citations are included in semantic search results.""" - # Setup mock with document ID - mock_search_client = AsyncMock() - mock_results = AsyncMock() - mock_doc = {"id": "doc123", "content": "Test document content"} - mock_results.__aiter__.return_value = iter([mock_doc]) - mock_search_client.search.return_value = mock_results - mock_search_class.return_value = mock_search_client - - provider = AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="test-key", - mode="semantic", - ) - - 
context = await provider.invoking([Message(role="user", text="test query")]) - - # Check that citation is included - assert isinstance(context, Context) - assert len(context.messages) > 1 # First message is prompt, rest are results - # Citation should be in the result message (second message) - assert "[Source: doc123]" in context.messages[1].text - assert "Test document content" in context.messages[1].text - - -class TestAgenticSearch: - """Test agentic search functionality.""" - - @pytest.mark.asyncio - @patch("agent_framework_azure_ai_search._search_provider.KnowledgeBaseRetrievalClient") - @patch("agent_framework_azure_ai_search._search_provider.SearchIndexClient") - @patch("agent_framework_azure_ai_search._search_provider.SearchClient") - async def test_agentic_search_basic( - self, - mock_search_class: MagicMock, - mock_index_class: MagicMock, - mock_retrieval_class: MagicMock, - sample_messages: list[Message], - ) -> None: - """Test basic agentic search with Knowledge Base retrieval.""" - # Setup search client mock - mock_search_client = AsyncMock() - mock_search_class.return_value = mock_search_client - - # Setup index client mock - mock_index_client = AsyncMock() - mock_index_class.return_value = mock_index_client - - # Setup retrieval client mock with response - mock_retrieval_client = AsyncMock() - mock_response = MagicMock() - mock_message = MagicMock() - mock_content = MagicMock() - mock_content.text = "Agentic search result" - # Make it pass isinstance check - from agent_framework_azure_ai_search._search_provider import _agentic_retrieval_available - - if _agentic_retrieval_available: - from azure.search.documents.knowledgebases.models import KnowledgeBaseMessageTextContent - - mock_content.__class__ = KnowledgeBaseMessageTextContent - mock_message.content = [mock_content] - mock_response.response = [mock_message] - mock_retrieval_client.retrieve.return_value = mock_response - mock_retrieval_client.close = AsyncMock() - 
mock_retrieval_class.return_value = mock_retrieval_client - - # Clear environment to ensure no env vars interfere - clean_env = {k: v for k, v in os.environ.items() if not k.startswith("AZURE_SEARCH_")} - with patch.dict(os.environ, clean_env, clear=True): - # Use knowledge_base_name path (existing KB) - provider = AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - api_key="test-key", - mode="agentic", - knowledge_base_name="test-kb", - env_file_path="", # Disable .env file loading - ) - - context = await provider.invoking(sample_messages) - - assert isinstance(context, Context) - # Should have at least the prompt message - assert len(context.messages) >= 1 - - @pytest.mark.asyncio - @patch("agent_framework_azure_ai_search._search_provider.KnowledgeBaseRetrievalClient") - @patch("agent_framework_azure_ai_search._search_provider.SearchIndexClient") - @patch("agent_framework_azure_ai_search._search_provider.SearchClient") - async def test_agentic_search_no_results( - self, - mock_search_class: MagicMock, - mock_index_class: MagicMock, - mock_retrieval_class: MagicMock, - sample_messages: list[Message], - ) -> None: - """Test agentic search when no results are returned.""" - # Setup mocks - mock_search_client = AsyncMock() - mock_search_class.return_value = mock_search_client - - mock_index_client = AsyncMock() - mock_index_class.return_value = mock_index_client - - # Empty response - mock_retrieval_client = AsyncMock() - mock_response = MagicMock() - mock_response.response = [] - mock_retrieval_client.retrieve.return_value = mock_response - mock_retrieval_client.close = AsyncMock() - mock_retrieval_class.return_value = mock_retrieval_client - - # Clear environment to ensure no env vars interfere - clean_env = {k: v for k, v in os.environ.items() if not k.startswith("AZURE_SEARCH_")} - with patch.dict(os.environ, clean_env, clear=True): - # Use knowledge_base_name path (existing KB) - provider = AzureAISearchContextProvider( - 
endpoint="https://test.search.windows.net", - api_key="test-key", - mode="agentic", - knowledge_base_name="test-kb", - env_file_path="", # Disable .env file loading - ) - - context = await provider.invoking(sample_messages) - - assert isinstance(context, Context) - # Should have fallback message - assert len(context.messages) >= 1 - - @pytest.mark.asyncio - @patch("agent_framework_azure_ai_search._search_provider.KnowledgeBaseRetrievalClient") - @patch("agent_framework_azure_ai_search._search_provider.SearchIndexClient") - @patch("agent_framework_azure_ai_search._search_provider.SearchClient") - async def test_agentic_search_with_medium_reasoning( - self, - mock_search_class: MagicMock, - mock_index_class: MagicMock, - mock_retrieval_class: MagicMock, - sample_messages: list[Message], - ) -> None: - """Test agentic search with medium reasoning effort.""" - # Setup mocks - mock_search_client = AsyncMock() - mock_search_class.return_value = mock_search_client - - mock_index_client = AsyncMock() - mock_index_class.return_value = mock_index_client - - mock_retrieval_client = AsyncMock() - mock_response = MagicMock() - mock_message = MagicMock() - mock_content = MagicMock() - mock_content.text = "Medium reasoning result" - from agent_framework_azure_ai_search._search_provider import _agentic_retrieval_available - - if _agentic_retrieval_available: - from azure.search.documents.knowledgebases.models import KnowledgeBaseMessageTextContent - - mock_content.__class__ = KnowledgeBaseMessageTextContent - mock_message.content = [mock_content] - mock_response.response = [mock_message] - mock_retrieval_client.retrieve.return_value = mock_response - mock_retrieval_client.close = AsyncMock() - mock_retrieval_class.return_value = mock_retrieval_client - - # Clear environment to ensure no env vars interfere - clean_env = {k: v for k, v in os.environ.items() if not k.startswith("AZURE_SEARCH_")} - with patch.dict(os.environ, clean_env, clear=True): - # Use knowledge_base_name path 
(existing KB) - provider = AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - api_key="test-key", - mode="agentic", - knowledge_base_name="test-kb", - retrieval_reasoning_effort="medium", # Test medium reasoning - env_file_path="", # Disable .env file loading - ) - - context = await provider.invoking(sample_messages) - - assert isinstance(context, Context) - assert len(context.messages) >= 1 - - -class TestVectorFieldAutoDiscovery: - """Test vector field auto-discovery functionality.""" - - @pytest.mark.asyncio - @patch("agent_framework_azure_ai_search._search_provider.SearchIndexClient") - @patch("agent_framework_azure_ai_search._search_provider.SearchClient") - async def test_auto_discovers_single_vector_field( - self, mock_search_class: MagicMock, mock_index_class: MagicMock - ) -> None: - """Test that single vector field is auto-discovered.""" - # Setup search client mock - mock_search_client = AsyncMock() - mock_search_class.return_value = mock_search_client - - # Setup index client mock - mock_index_client = AsyncMock() - mock_index = MagicMock() - - # Create mock field with vector_search_dimensions attribute - mock_vector_field = MagicMock() - mock_vector_field.name = "embedding_vector" - mock_vector_field.vector_search_dimensions = 1536 - - mock_index.fields = [mock_vector_field] - mock_index_client.get_index.return_value = mock_index - mock_index_client.close = AsyncMock() - mock_index_class.return_value = mock_index_client - - # Create provider without specifying vector_field_name - provider = AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="test-key", - mode="semantic", - ) - - # Trigger auto-discovery - await provider._auto_discover_vector_field() - - # Vector field should be auto-discovered but not used without embedding function - assert provider._auto_discovered_vector_field is True - # Should be cleared since no embedding function - assert 
provider.vector_field_name is None - - @pytest.mark.asyncio - async def test_vector_detection_accuracy(self) -> None: - """Test that vector field detection logic correctly identifies vector fields.""" - from azure.search.documents.indexes.models import SearchField - - # Create real SearchField objects to test the detection logic - vector_field = SearchField( - name="embedding_vector", type="Collection(Edm.Single)", vector_search_dimensions=1536, searchable=True - ) - - string_field = SearchField(name="content", type="Edm.String", searchable=True) - - number_field = SearchField(name="price", type="Edm.Double", filterable=True) - - # Test detection logic directly - is_vector_1 = vector_field.vector_search_dimensions is not None and vector_field.vector_search_dimensions > 0 - is_vector_2 = string_field.vector_search_dimensions is not None and string_field.vector_search_dimensions > 0 - is_vector_3 = number_field.vector_search_dimensions is not None and number_field.vector_search_dimensions > 0 - - # Only the vector field should be detected - assert is_vector_1 is True - assert is_vector_2 is False - assert is_vector_3 is False - - @pytest.mark.asyncio - @patch("agent_framework_azure_ai_search._search_provider.SearchIndexClient") - @patch("agent_framework_azure_ai_search._search_provider.SearchClient") - async def test_no_false_positives_on_string_fields( - self, mock_search_class: MagicMock, mock_index_class: MagicMock - ) -> None: - """Test that regular string fields are not detected as vector fields.""" - # Setup search client mock - mock_search_client = AsyncMock() - mock_search_class.return_value = mock_search_client - - # Setup index with only string fields (no vectors) - mock_index_client = AsyncMock() - mock_index = MagicMock() - - # All fields have vector_search_dimensions = None - mock_fields = [] - for name in ["id", "title", "content", "category"]: - field = MagicMock() - field.name = name - field.vector_search_dimensions = None - 
field.vector_search_profile_name = None - mock_fields.append(field) - - mock_index.fields = mock_fields - mock_index_client.get_index.return_value = mock_index - mock_index_client.close = AsyncMock() - mock_index_class.return_value = mock_index_client - - # Create provider - provider = AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="test-key", - mode="semantic", - ) - - # Trigger auto-discovery - await provider._auto_discover_vector_field() - - # Should NOT detect any vector fields - assert provider.vector_field_name is None - assert provider._auto_discovered_vector_field is True - - @pytest.mark.asyncio - @patch("agent_framework_azure_ai_search._search_provider.SearchIndexClient") - @patch("agent_framework_azure_ai_search._search_provider.SearchClient") - async def test_multiple_vector_fields_without_vectorizer( - self, mock_search_class: MagicMock, mock_index_class: MagicMock - ) -> None: - """Test that multiple vector fields without vectorizer logs warning and uses keyword search.""" - # Setup search client mock - mock_search_client = AsyncMock() - mock_search_class.return_value = mock_search_client - - # Setup index with multiple vector fields (no vectorizers) - mock_index_client = AsyncMock() - mock_index = MagicMock() - - # Multiple vector fields - mock_fields = [] - for name in ["embedding1", "embedding2"]: - field = MagicMock() - field.name = name - field.vector_search_dimensions = 1536 - field.vector_search_profile_name = None # No vectorizer - mock_fields.append(field) - - mock_index.fields = mock_fields - mock_index.vector_search = None # No vector search config - mock_index_client.get_index.return_value = mock_index - mock_index_client.close = AsyncMock() - mock_index_class.return_value = mock_index_client - - # Create provider - provider = AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="test-key", - mode="semantic", - ) - 
- # Trigger auto-discovery - await provider._auto_discover_vector_field() - - # Should NOT use any vector field (multiple fields, can't choose) - assert provider.vector_field_name is None - assert provider._auto_discovered_vector_field is True - - @pytest.mark.asyncio - @patch("agent_framework_azure_ai_search._search_provider.SearchIndexClient") - @patch("agent_framework_azure_ai_search._search_provider.SearchClient") - async def test_multiple_vectorizable_fields( - self, mock_search_class: MagicMock, mock_index_class: MagicMock - ) -> None: - """Test that multiple vectorizable fields logs warning and uses keyword search.""" - # Setup search client mock - mock_search_client = AsyncMock() - mock_search_class.return_value = mock_search_client - - # Setup index with multiple vectorizable fields - mock_index_client = AsyncMock() - mock_index = MagicMock() - - # Multiple vector fields with vectorizers - mock_fields = [] - for name in ["embedding1", "embedding2"]: - field = MagicMock() - field.name = name - field.vector_search_dimensions = 1536 - field.vector_search_profile_name = f"{name}-profile" - mock_fields.append(field) - - mock_index.fields = mock_fields - - # Setup vector search config with profiles that have vectorizers - mock_profile1 = MagicMock() - mock_profile1.name = "embedding1-profile" - mock_profile1.vectorizer_name = "vectorizer1" - - mock_profile2 = MagicMock() - mock_profile2.name = "embedding2-profile" - mock_profile2.vectorizer_name = "vectorizer2" - - mock_index.vector_search = MagicMock() - mock_index.vector_search.profiles = [mock_profile1, mock_profile2] - - mock_index_client.get_index.return_value = mock_index - mock_index_client.close = AsyncMock() - mock_index_class.return_value = mock_index_client - - # Create provider - provider = AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="test-key", - mode="semantic", - ) - - # Trigger auto-discovery - await 
provider._auto_discover_vector_field() - - # Should NOT use any vector field (multiple vectorizable fields, can't choose) - assert provider.vector_field_name is None - assert provider._auto_discovered_vector_field is True - - @pytest.mark.asyncio - @patch("agent_framework_azure_ai_search._search_provider.SearchIndexClient") - @patch("agent_framework_azure_ai_search._search_provider.SearchClient") - async def test_single_vectorizable_field_detected( - self, mock_search_class: MagicMock, mock_index_class: MagicMock - ) -> None: - """Test that single vectorizable field is auto-detected for server-side vectorization.""" - # Setup search client mock - mock_search_client = AsyncMock() - mock_search_class.return_value = mock_search_client - - # Setup index with single vectorizable field - mock_index_client = AsyncMock() - mock_index = MagicMock() - - # Single vector field with vectorizer - mock_field = MagicMock() - mock_field.name = "embedding" - mock_field.vector_search_dimensions = 1536 - mock_field.vector_search_profile_name = "embedding-profile" - - mock_index.fields = [mock_field] - - # Setup vector search config with profile that has vectorizer - mock_profile = MagicMock() - mock_profile.name = "embedding-profile" - mock_profile.vectorizer_name = "openai-vectorizer" - - mock_index.vector_search = MagicMock() - mock_index.vector_search.profiles = [mock_profile] - - mock_index_client.get_index.return_value = mock_index - mock_index_client.close = AsyncMock() - mock_index_class.return_value = mock_index_client - - # Create provider - provider = AzureAISearchContextProvider( - endpoint="https://test.search.windows.net", - index_name="test-index", - api_key="test-key", - mode="semantic", - ) - - # Trigger auto-discovery - await provider._auto_discover_vector_field() - - # Should detect the vectorizable field - assert provider.vector_field_name == "embedding" - assert provider._auto_discovered_vector_field is True - assert provider._use_vectorizable_query is True # 
Server-side vectorization diff --git a/python/packages/azure-ai/tests/test_azure_ai_agent_client.py b/python/packages/azure-ai/tests/test_azure_ai_agent_client.py index 80c99f6e89..8ca3f0dc56 100644 --- a/python/packages/azure-ai/tests/test_azure_ai_agent_client.py +++ b/python/packages/azure-ai/tests/test_azure_ai_agent_client.py @@ -11,7 +11,7 @@ Agent, AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, ChatOptions, ChatResponse, ChatResponseUpdate, @@ -1524,24 +1524,24 @@ async def test_azure_ai_chat_client_agent_basic_run_streaming() -> None: @pytest.mark.flaky @skip_if_azure_ai_integration_tests_disabled async def test_azure_ai_chat_client_agent_thread_persistence() -> None: - """Test Agent thread persistence across runs with AzureAIAgentClient.""" + """Test Agent session persistence across runs with AzureAIAgentClient.""" async with Agent( client=AzureAIAgentClient(credential=AzureCliCredential()), instructions="You are a helpful assistant with good memory.", ) as agent: - # Create a new thread that will be reused - thread = agent.get_new_thread() + # Create a new session that will be reused + session = agent.create_session() # First message - establish context first_response = await agent.run( - "Remember this number: 42. What number did I just tell you to remember?", thread=thread + "Remember this number: 42. 
What number did I just tell you to remember?", session=session ) assert isinstance(first_response, AgentResponse) assert "42" in first_response.text # Second message - test conversation memory second_response = await agent.run( - "What number did I tell you to remember in my previous message?", thread=thread + "What number did I tell you to remember in my previous message?", session=session ) assert isinstance(second_response, AgentResponse) assert "42" in second_response.text @@ -1555,16 +1555,16 @@ async def test_azure_ai_chat_client_agent_existing_thread_id() -> None: client=AzureAIAgentClient(credential=AzureCliCredential()), instructions="You are a helpful assistant with good memory.", ) as first_agent: - # Start a conversation and get the thread ID - thread = first_agent.get_new_thread() - first_response = await first_agent.run("My name is Alice. Remember this.", thread=thread) + # Start a conversation and get the session ID + session = first_agent.create_session() + first_response = await first_agent.run("My name is Alice. 
Remember this.", session=session) # Validate first response assert isinstance(first_response, AgentResponse) assert first_response.text is not None # The thread ID is set after the first response - existing_thread_id = thread.service_thread_id + existing_thread_id = session.service_session_id assert existing_thread_id is not None # Now continue with the same thread ID in a new agent instance @@ -1572,11 +1572,11 @@ async def test_azure_ai_chat_client_agent_existing_thread_id() -> None: client=AzureAIAgentClient(thread_id=existing_thread_id, credential=AzureCliCredential()), instructions="You are a helpful assistant with good memory.", ) as second_agent: - # Create a thread with the existing ID - thread = AgentThread(service_thread_id=existing_thread_id) + # Create a session with the existing ID + session = AgentSession(service_session_id=existing_thread_id) # Ask about the previous conversation - response2 = await second_agent.run("What is my name?", thread=thread) + response2 = await second_agent.run("What is my name?", session=session) # Validate that the agent remembers the previous conversation assert isinstance(response2, AgentResponse) diff --git a/python/packages/claude/tests/test_claude_agent.py b/python/packages/claude/tests/test_claude_agent.py index cc28df81c3..13e625b793 100644 --- a/python/packages/claude/tests/test_claude_agent.py +++ b/python/packages/claude/tests/test_claude_agent.py @@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from agent_framework import AgentResponseUpdate, AgentThread, Content, Message, tool +from agent_framework import AgentResponseUpdate, AgentSession, Content, Message, tool from agent_framework._settings import load_settings from agent_framework_claude import ClaudeAgent, ClaudeAgentOptions, ClaudeAgentSettings @@ -267,12 +267,12 @@ async def test_run_captures_session_id(self) -> None: with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=mock_client): agent = 
ClaudeAgent() - thread = agent.get_new_thread() - await agent.run("Hello", thread=thread) - assert thread.service_thread_id == "test-session-id" + session = agent.create_session() + await agent.run("Hello", session=session) + assert session.service_session_id == "test-session-id" - async def test_run_with_thread(self) -> None: - """Test run with existing thread.""" + async def test_run_with_session(self) -> None: + """Test run with existing session.""" from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock from claude_agent_sdk.types import StreamEvent @@ -302,9 +302,9 @@ async def test_run_with_thread(self) -> None: with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=mock_client): agent = ClaudeAgent() - thread = agent.get_new_thread() - thread.service_thread_id = "existing-session" - await agent.run("Hello", thread=thread) + session = agent.create_session() + session.service_session_id = "existing-session" + await agent.run("Hello", session=session) # region Test ClaudeAgent Run Stream @@ -440,26 +440,25 @@ async def test_run_stream_raises_on_result_message_error(self) -> None: class TestClaudeAgentSessionManagement: """Tests for ClaudeAgent session management.""" - def test_get_new_thread(self) -> None: - """Test get_new_thread creates a new thread.""" + def test_create_session(self) -> None: + """Test create_session creates a new session.""" agent = ClaudeAgent() - thread = agent.get_new_thread() - assert isinstance(thread, AgentThread) - assert thread.service_thread_id is None + session = agent.create_session() + assert isinstance(session, AgentSession) + assert session.service_session_id is None - def test_get_new_thread_with_service_thread_id(self) -> None: - """Test get_new_thread with existing service_thread_id.""" + def test_create_session_with_service_session_id(self) -> None: + """Test create_session with existing service_session_id.""" agent = ClaudeAgent() - thread = 
agent.get_new_thread(service_thread_id="existing-session-123") - assert isinstance(thread, AgentThread) - assert thread.service_thread_id == "existing-session-123" + session = agent.create_session(session_id="existing-session-123") + assert isinstance(session, AgentSession) - def test_thread_inherits_context_provider(self) -> None: - """Test that thread inherits context provider.""" + def test_session_inherits_context_provider(self) -> None: + """Test that session inherits context provider.""" mock_provider = MagicMock() - agent = ClaudeAgent(context_provider=mock_provider) - thread = agent.get_new_thread() - assert thread.context_provider == mock_provider + agent = ClaudeAgent(context_providers=[mock_provider]) + session = agent.create_session() + assert mock_provider in agent.context_providers async def test_ensure_session_creates_client(self) -> None: """Test _ensure_session creates client when not started.""" diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 10c5bc4b58..d99f14eefb 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -675,7 +675,7 @@ def __init__( self.default_options: dict[str, Any] = { "model_id": opts.pop("model_id", None) or (getattr(self.client, "model_id", None)), "allow_multiple_tool_calls": opts.pop("allow_multiple_tool_calls", None), - "conversation_id": conversation_id, + "conversation_id": opts.pop("conversation_id", None), "frequency_penalty": opts.pop("frequency_penalty", None), "instructions": instructions_, "logit_bias": opts.pop("logit_bias", None), diff --git a/python/packages/core/agent_framework/mem0/__init__.py b/python/packages/core/agent_framework/mem0/__init__.py index dd28c5459b..dddc742ef0 100644 --- a/python/packages/core/agent_framework/mem0/__init__.py +++ b/python/packages/core/agent_framework/mem0/__init__.py @@ -5,7 +5,7 @@ IMPORT_PATH = "agent_framework_mem0" PACKAGE_NAME = 
"agent-framework-mem0" -_IMPORTS = ["__version__", "Mem0Provider"] +_IMPORTS = ["__version__", "Mem0ContextProvider"] def __getattr__(name: str) -> Any: diff --git a/python/packages/core/agent_framework/redis/__init__.py b/python/packages/core/agent_framework/redis/__init__.py index 85594715cb..9f96b3455f 100644 --- a/python/packages/core/agent_framework/redis/__init__.py +++ b/python/packages/core/agent_framework/redis/__init__.py @@ -5,7 +5,7 @@ IMPORT_PATH = "agent_framework_redis" PACKAGE_NAME = "agent-framework-redis" -_IMPORTS = ["__version__", "RedisProvider", "RedisChatMessageStore"] +_IMPORTS = ["__version__", "RedisContextProvider", "RedisHistoryProvider"] def __getattr__(name: str) -> Any: diff --git a/python/packages/core/tests/azure/test_azure_assistants_client.py b/python/packages/core/tests/azure/test_azure_assistants_client.py index bffa678a33..bb93ca71d4 100644 --- a/python/packages/core/tests/azure/test_azure_assistants_client.py +++ b/python/packages/core/tests/azure/test_azure_assistants_client.py @@ -12,7 +12,7 @@ Agent, AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, ChatResponse, ChatResponseUpdate, Message, @@ -439,25 +439,25 @@ async def test_azure_assistants_agent_thread_persistence(): client=AzureOpenAIAssistantsClient(credential=AzureCliCredential()), instructions="You are a helpful assistant with good memory.", ) as agent: - # Create a new thread that will be reused - thread = agent.get_new_thread() + # Create a new session that will be reused + session = agent.create_session() # First message - establish context first_response = await agent.run( - "Remember this number: 42. What number did I just tell you to remember?", thread=thread + "Remember this number: 42. 
What number did I just tell you to remember?", session=session ) assert isinstance(first_response, AgentResponse) assert "42" in first_response.text # Second message - test conversation memory second_response = await agent.run( - "What number did I tell you to remember in my previous message?", thread=thread + "What number did I tell you to remember in my previous message?", session=session ) assert isinstance(second_response, AgentResponse) assert "42" in second_response.text - # Verify thread has been populated with conversation ID - assert thread.service_thread_id is not None + # Verify session has been populated with conversation ID + assert session.service_session_id is not None @pytest.mark.flaky @@ -472,17 +472,17 @@ async def test_azure_assistants_agent_existing_thread_id(): instructions="You are a helpful weather agent.", tools=[get_weather], ) as agent: - # Start a conversation and get the thread ID - thread = agent.get_new_thread() - response1 = await agent.run("What's the weather in Paris?", thread=thread) + # Start a conversation and get the session ID + session = agent.create_session() + response1 = await agent.run("What's the weather in Paris?", session=session) # Validate first response assert isinstance(response1, AgentResponse) assert response1.text is not None assert any(word in response1.text.lower() for word in ["weather", "paris"]) - # The thread ID is set after the first response - existing_thread_id = thread.service_thread_id + # The session ID is set after the first response + existing_thread_id = session.service_session_id assert existing_thread_id is not None # Now continue with the same thread ID in a new agent instance @@ -492,11 +492,11 @@ async def test_azure_assistants_agent_existing_thread_id(): instructions="You are a helpful weather agent.", tools=[get_weather], ) as agent: - # Create a thread with the existing ID - thread = AgentThread(service_thread_id=existing_thread_id) + # Create a session with the existing ID + session = 
AgentSession(service_session_id=existing_thread_id) # Ask about the previous conversation - response2 = await agent.run("What was the last city I asked about?", thread=thread) + response2 = await agent.run("What was the last city I asked about?", session=session) # Validate that the agent remembers the previous conversation assert isinstance(response2, AgentResponse) diff --git a/python/packages/core/tests/azure/test_azure_chat_client.py b/python/packages/core/tests/azure/test_azure_chat_client.py index 7618755a76..2c4cf331b3 100644 --- a/python/packages/core/tests/azure/test_azure_chat_client.py +++ b/python/packages/core/tests/azure/test_azure_chat_client.py @@ -806,17 +806,17 @@ async def test_azure_openai_chat_client_agent_thread_persistence(): client=AzureOpenAIChatClient(credential=AzureCliCredential()), instructions="You are a helpful assistant with good memory.", ) as agent: - # Create a new thread that will be reused - thread = agent.get_new_thread() + # Create a new session that will be reused + session = agent.create_session() # First interaction - response1 = await agent.run("My name is Alice. Remember this.", thread=thread) + response1 = await agent.run("My name is Alice. 
Remember this.", session=session) assert isinstance(response1, AgentResponse) assert response1.text is not None # Second interaction - test memory - response2 = await agent.run("What is my name?", thread=thread) + response2 = await agent.run("What is my name?", session=session) assert isinstance(response2, AgentResponse) assert response2.text is not None @@ -827,31 +827,31 @@ async def test_azure_openai_chat_client_agent_thread_persistence(): @skip_if_azure_integration_tests_disabled async def test_azure_openai_chat_client_agent_existing_thread(): """Test Azure OpenAI chat client agent with existing thread to continue conversations across agent instances.""" - # First conversation - capture the thread - preserved_thread = None + # First conversation - capture the session + preserved_session = None async with Agent( client=AzureOpenAIChatClient(credential=AzureCliCredential()), instructions="You are a helpful assistant with good memory.", ) as first_agent: - # Start a conversation and capture the thread - thread = first_agent.get_new_thread() - first_response = await first_agent.run("My name is Alice. Remember this.", thread=thread) + # Start a conversation and capture the session + session = first_agent.create_session() + first_response = await first_agent.run("My name is Alice. 
Remember this.", session=session) assert isinstance(first_response, AgentResponse) assert first_response.text is not None - # Preserve the thread for reuse - preserved_thread = thread + # Preserve the session for reuse + preserved_session = session - # Second conversation - reuse the thread in a new agent instance - if preserved_thread: + # Second conversation - reuse the session in a new agent instance + if preserved_session: async with Agent( client=AzureOpenAIChatClient(credential=AzureCliCredential()), instructions="You are a helpful assistant with good memory.", ) as second_agent: - # Reuse the preserved thread - second_response = await second_agent.run("What is my name?", thread=preserved_thread) + # Reuse the preserved session + second_response = await second_agent.run("What is my name?", session=preserved_session) assert isinstance(second_response, AgentResponse) assert second_response.text is not None diff --git a/python/packages/core/tests/azure/test_azure_responses_client.py b/python/packages/core/tests/azure/test_azure_responses_client.py index 1d40c769db..d4705b3aab 100644 --- a/python/packages/core/tests/azure/test_azure_responses_client.py +++ b/python/packages/core/tests/azure/test_azure_responses_client.py @@ -539,31 +539,31 @@ async def test_integration_client_agent_hosted_code_interpreter_tool(): @skip_if_azure_integration_tests_disabled async def test_integration_client_agent_existing_thread(): """Test Azure Responses Client agent with existing thread to continue conversations across agent instances.""" - # First conversation - capture the thread - preserved_thread = None + # First conversation - capture the session + preserved_session = None async with Agent( client=AzureOpenAIResponsesClient(credential=AzureCliCredential()), instructions="You are a helpful assistant with good memory.", ) as first_agent: - # Start a conversation and capture the thread - thread = first_agent.get_new_thread() - first_response = await first_agent.run("My hobby is 
photography. Remember this.", thread=thread, store=True) + # Start a conversation and capture the session + session = first_agent.create_session() + first_response = await first_agent.run("My hobby is photography. Remember this.", session=session, store=True) assert isinstance(first_response, AgentResponse) assert first_response.text is not None - # Preserve the thread for reuse - preserved_thread = thread + # Preserve the session for reuse + preserved_session = session - # Second conversation - reuse the thread in a new agent instance - if preserved_thread: + # Second conversation - reuse the session in a new agent instance + if preserved_session: async with Agent( client=AzureOpenAIResponsesClient(credential=AzureCliCredential()), instructions="You are a helpful assistant with good memory.", ) as second_agent: - # Reuse the preserved thread - second_response = await second_agent.run("What is my hobby?", thread=preserved_thread) + # Reuse the preserved session + second_response = await second_agent.run("What is my hobby?", session=preserved_session) assert isinstance(second_response, AgentResponse) assert second_response.text is not None diff --git a/python/packages/core/tests/core/conftest.py b/python/packages/core/tests/core/conftest.py index 60df064d19..2d1eec2d9a 100644 --- a/python/packages/core/tests/core/conftest.py +++ b/python/packages/core/tests/core/conftest.py @@ -13,7 +13,7 @@ from agent_framework import ( AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, BaseChatClient, ChatMiddlewareLayer, ChatResponse, @@ -261,7 +261,7 @@ def chat_client_base(enable_function_calling: bool, max_iterations: int) -> Mock # region Agents -class MockAgentThread(AgentThread): +class MockAgentSession(AgentSession): pass @@ -284,41 +284,41 @@ def run( self, messages: str | Message | list[str] | list[Message] | None = None, *, - thread: AgentThread | None = None, + session: AgentSession | None = None, stream: bool = False, **kwargs: Any, ) -> 
Awaitable[AgentResponse] | AsyncIterable[AgentResponseUpdate]: if stream: - return self._run_stream_impl(messages=messages, thread=thread, **kwargs) - return self._run_impl(messages=messages, thread=thread, **kwargs) + return self._run_stream_impl(messages=messages, session=session, **kwargs) + return self._run_impl(messages=messages, session=session, **kwargs) async def _run_impl( self, messages: str | Message | list[str] | list[Message] | None = None, *, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> AgentResponse: - logger.debug(f"Running mock agent, with: {messages=}, {thread=}, {kwargs=}") + logger.debug(f"Running mock agent, with: {messages=}, {session=}, {kwargs=}") return AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Response")])]) async def _run_stream_impl( self, messages: str | Message | list[str] | list[Message] | None = None, *, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> AsyncIterable[AgentResponseUpdate]: - logger.debug(f"Running mock agent stream, with: {messages=}, {thread=}, {kwargs=}") + logger.debug(f"Running mock agent stream, with: {messages=}, {session=}, {kwargs=}") yield AgentResponseUpdate(contents=[Content.from_text("Response")]) - def get_new_thread(self) -> AgentThread: - return MockAgentThread() + def create_session(self) -> AgentSession: + return MockAgentSession() @fixture -def agent_thread() -> AgentThread: - return MockAgentThread() +def agent_session() -> AgentSession: + return MockAgentSession() @fixture diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index fcb4542a24..c2976c85bc 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -13,13 +13,11 @@ Agent, AgentResponse, AgentResponseUpdate, - AgentThread, - ChatMessageStore, + AgentSession, ChatOptions, ChatResponse, Content, 
- Context, - ContextProvider, + BaseContextProvider, FunctionTool, Message, SupportsAgentRun, @@ -28,11 +26,10 @@ ) from agent_framework._agents import _merge_options, _sanitize_agent_name from agent_framework._mcp import MCPTool -from agent_framework.exceptions import AgentExecutionException, AgentInitializationError -def test_agent_thread_type(agent_thread: AgentThread) -> None: - assert isinstance(agent_thread, AgentThread) +def test_agent_session_type(agent_session: AgentSession) -> None: + assert isinstance(agent_session, AgentSession) def test_agent_type(agent: SupportsAgentRun) -> None: @@ -93,38 +90,42 @@ async def test_chat_client_agent_run_streaming(client: SupportsChatGetResponse) assert result.text == "test streaming response another update" -async def test_chat_client_agent_get_new_thread(client: SupportsChatGetResponse) -> None: +async def test_chat_client_agent_create_session(client: SupportsChatGetResponse) -> None: agent = Agent(client=client) - thread = agent.get_new_thread() + session = agent.create_session() - assert isinstance(thread, AgentThread) + assert isinstance(session, AgentSession) -async def test_chat_client_agent_prepare_thread_and_messages(client: SupportsChatGetResponse) -> None: - agent = Agent(client=client) +async def test_chat_client_agent_prepare_session_and_messages(client: SupportsChatGetResponse) -> None: + from agent_framework._sessions import InMemoryHistoryProvider + + agent = Agent(client=client, context_providers=[InMemoryHistoryProvider("memory")]) message = Message(role="user", text="Hello") - thread = AgentThread(message_store=ChatMessageStore(messages=[message])) + session = AgentSession() + session.state["memory"] = {"messages": [message]} - _, _, result_messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage] - thread=thread, + session_context, _ = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + session=session, input_messages=[Message(role="user", 
text="Test")], ) + result_messages = session_context.get_messages(include_input=True) assert len(result_messages) == 2 - assert result_messages[0] == message + assert result_messages[0].text == "Hello" assert result_messages[1].text == "Test" -async def test_prepare_thread_does_not_mutate_agent_chat_options(client: SupportsChatGetResponse) -> None: +async def test_prepare_session_does_not_mutate_agent_chat_options(client: SupportsChatGetResponse) -> None: tool = {"type": "code_interpreter"} agent = Agent(client=client, tools=[tool]) assert agent.default_options.get("tools") is not None base_tools = agent.default_options["tools"] - thread = agent.get_new_thread() + session = agent.create_session() - _, prepared_chat_options, _ = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage] - thread=thread, + _, prepared_chat_options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + session=session, input_messages=[Message(role="user", text="Test")], ) @@ -135,7 +136,7 @@ async def test_prepare_thread_does_not_mutate_agent_chat_options(client: Support assert len(agent.default_options["tools"]) == 1 -async def test_chat_client_agent_update_thread_id(chat_client_base: SupportsChatGetResponse) -> None: +async def test_chat_client_agent_run_with_session(chat_client_base: SupportsChatGetResponse) -> None: mock_response = ChatResponse( messages=[Message(role="assistant", contents=[Content.from_text("test response")])], conversation_id="123", @@ -145,25 +146,24 @@ async def test_chat_client_agent_update_thread_id(chat_client_base: SupportsChat client=chat_client_base, tools={"type": "code_interpreter"}, ) - thread = agent.get_new_thread() + session = agent.get_session("123") - result = await agent.run("Hello", thread=thread) + result = await agent.run("Hello", session=session) assert result.text == "test response" - assert thread.service_thread_id == "123" + assert session.service_session_id == "123" -async def 
test_chat_client_agent_update_thread_messages(client: SupportsChatGetResponse) -> None: +async def test_chat_client_agent_update_session_messages(client: SupportsChatGetResponse) -> None: agent = Agent(client=client) - thread = agent.get_new_thread() + session = agent.create_session() - result = await agent.run("Hello", thread=thread) + result = await agent.run("Hello", session=session) assert result.text == "test response" - assert thread.service_thread_id is None - assert thread.message_store is not None + assert session.service_session_id is None - chat_messages: list[Message] = await thread.message_store.list_messages() + chat_messages: list[Message] = session.state.get("memory", {}).get("messages", []) assert chat_messages is not None assert len(chat_messages) == 2 @@ -171,12 +171,12 @@ async def test_chat_client_agent_update_thread_messages(client: SupportsChatGetR assert chat_messages[1].text == "test response" -async def test_chat_client_agent_update_thread_conversation_id_missing(client: SupportsChatGetResponse) -> None: +async def test_chat_client_agent_update_session_conversation_id_missing(client: SupportsChatGetResponse) -> None: agent = Agent(client=client) - thread = AgentThread(service_thread_id="123") + session = AgentSession(service_session_id="123") - with raises(AgentExecutionException, match="Service did not return a valid conversation id"): - await agent._update_thread_with_type_and_conversation_id(thread, None) # type: ignore[reportPrivateUsage] + # With the session-based API, service_session_id is managed directly on the session + assert session.service_session_id == "123" async def test_chat_client_agent_default_author_name(client: SupportsChatGetResponse) -> None: @@ -214,54 +214,41 @@ async def test_chat_client_agent_author_name_is_used_from_response(chat_client_b # Mock context provider for testing -class MockContextProvider(ContextProvider): +class MockContextProvider(BaseContextProvider): def __init__(self, messages: list[Message] | 
None = None) -> None: + super().__init__(source_id="mock") self.context_messages = messages - self.thread_created_called = False - self.invoked_called = False - self.invoking_called = False - self.thread_created_thread_id = None - self.invoked_thread_id = None + self.before_run_called = False + self.after_run_called = False self.new_messages: list[Message] = [] + self.last_service_session_id: str | None = None + + async def before_run(self, *, agent: Any, session: Any, context: Any, state: Any) -> None: + self.before_run_called = True + if self.context_messages: + context.extend_messages(self, self.context_messages) - async def thread_created(self, thread_id: str | None) -> None: - self.thread_created_called = True - self.thread_created_thread_id = thread_id - - async def invoked( - self, - request_messages: Message | Sequence[Message], - response_messages: Message | Sequence[Message] | None = None, - invoke_exception: Any = None, - **kwargs: Any, - ) -> None: - self.invoked_called = True - if isinstance(request_messages, Message): - self.new_messages.append(request_messages) - else: - self.new_messages.extend(request_messages) - if isinstance(response_messages, Message): - self.new_messages.append(response_messages) - else: - self.new_messages.extend(response_messages) - - async def invoking(self, messages: Message | MutableSequence[Message], **kwargs: Any) -> Context: - self.invoking_called = True - return Context(messages=self.context_messages) - - -async def test_chat_agent_context_providers_model_invoking(client: SupportsChatGetResponse) -> None: - """Test that context providers' invoking is called during agent run.""" + async def after_run(self, *, agent: Any, session: Any, context: Any, state: Any) -> None: + self.after_run_called = True + if session: + self.last_service_session_id = session.service_session_id + if context.response: + self.new_messages.extend(context.input_messages) + self.new_messages.extend(context.response.messages) + + +async def 
test_chat_agent_context_providers_model_before_run(client: SupportsChatGetResponse) -> None: + """Test that context providers' before_run is called during agent run.""" mock_provider = MockContextProvider(messages=[Message(role="system", text="Test context instructions")]) - agent = Agent(client=client, context_provider=mock_provider) + agent = Agent(client=client, context_providers=[mock_provider]) await agent.run("Hello") - assert mock_provider.invoking_called + assert mock_provider.before_run_called -async def test_chat_agent_context_providers_thread_created(chat_client_base: SupportsChatGetResponse) -> None: - """Test that context providers' thread_created is called during agent run.""" +async def test_chat_agent_context_providers_after_run(chat_client_base: SupportsChatGetResponse) -> None: + """Test that context providers' after_run is called during agent run.""" mock_provider = MockContextProvider() chat_client_base.run_responses = [ ChatResponse( @@ -270,22 +257,23 @@ async def test_chat_agent_context_providers_thread_created(chat_client_base: Sup ) ] - agent = Agent(client=chat_client_base, context_provider=mock_provider) + agent = Agent(client=chat_client_base, context_providers=[mock_provider]) - await agent.run("Hello") + session = agent.get_session("test-thread-id") + await agent.run("Hello", session=session) - assert mock_provider.thread_created_called - assert mock_provider.thread_created_thread_id == "test-thread-id" + assert mock_provider.after_run_called + assert mock_provider.last_service_session_id == "test-thread-id" async def test_chat_agent_context_providers_messages_adding(client: SupportsChatGetResponse) -> None: - """Test that context providers' invoked is called during agent run.""" + """Test that context providers' after_run is called during agent run.""" mock_provider = MockContextProvider() - agent = Agent(client=client, context_provider=mock_provider) + agent = Agent(client=client, context_providers=[mock_provider]) await 
agent.run("Hello") - assert mock_provider.invoked_called + assert mock_provider.after_run_called # Should be called with both input and response messages assert len(mock_provider.new_messages) >= 2 @@ -293,12 +281,13 @@ async def test_chat_agent_context_providers_messages_adding(client: SupportsChat async def test_chat_agent_context_instructions_in_messages(client: SupportsChatGetResponse) -> None: """Test that AI context instructions are included in messages.""" mock_provider = MockContextProvider(messages=[Message(role="system", text="Context-specific instructions")]) - agent = Agent(client=client, instructions="Agent instructions", context_provider=mock_provider) + agent = Agent(client=client, instructions="Agent instructions", context_providers=[mock_provider]) - # We need to test the _prepare_thread_and_messages method directly - _, _, messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage] - thread=None, input_messages=[Message(role="user", text="Hello")] + # We need to test the _prepare_session_and_messages method directly + session_context, _ = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + session=None, input_messages=[Message(role="user", text="Hello")] ) + messages = session_context.get_messages(include_input=True) # Should have context instructions, and user message assert len(messages) == 2 @@ -312,11 +301,12 @@ async def test_chat_agent_context_instructions_in_messages(client: SupportsChatG async def test_chat_agent_no_context_instructions(client: SupportsChatGetResponse) -> None: """Test behavior when AI context has no instructions.""" mock_provider = MockContextProvider() - agent = Agent(client=client, instructions="Agent instructions", context_provider=mock_provider) + agent = Agent(client=client, instructions="Agent instructions", context_providers=[mock_provider]) - _, _, messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage] - thread=None, 
input_messages=[Message(role="user", text="Hello")] + session_context, _ = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + session=None, input_messages=[Message(role="user", text="Hello")] ) + messages = session_context.get_messages(include_input=True) # Should have agent instructions and user message only assert len(messages) == 1 @@ -327,7 +317,7 @@ async def test_chat_agent_no_context_instructions(client: SupportsChatGetRespons async def test_chat_agent_run_stream_context_providers(client: SupportsChatGetResponse) -> None: """Test that context providers work with run method.""" mock_provider = MockContextProvider(messages=[Message(role="system", text="Stream context instructions")]) - agent = Agent(client=client, context_provider=mock_provider) + agent = Agent(client=client, context_providers=[mock_provider]) # Collect all stream updates and get final response stream = agent.run("Hello", stream=True) @@ -338,14 +328,12 @@ async def test_chat_agent_run_stream_context_providers(client: SupportsChatGetRe await stream.get_final_response() # Verify context provider was called - assert mock_provider.invoking_called - # no conversation id is created, so no need to thread_create to be called. 
- assert not mock_provider.thread_created_called - assert mock_provider.invoked_called + assert mock_provider.before_run_called + assert mock_provider.after_run_called -async def test_chat_agent_context_providers_with_thread_service_id(chat_client_base: SupportsChatGetResponse) -> None: - """Test context providers with service-managed thread.""" +async def test_chat_agent_context_providers_with_service_session_id(chat_client_base: SupportsChatGetResponse) -> None: + """Test context providers with service-managed session.""" mock_provider = MockContextProvider() chat_client_base.run_responses = [ ChatResponse( @@ -354,14 +342,14 @@ async def test_chat_agent_context_providers_with_thread_service_id(chat_client_b ) ] - agent = Agent(client=chat_client_base, context_provider=mock_provider) + agent = Agent(client=chat_client_base, context_providers=[mock_provider]) - # Use existing service-managed thread - thread = agent.get_new_thread(service_thread_id="existing-thread-id") - await agent.run("Hello", thread=thread) + # Use existing service-managed session + session = agent.get_session("existing-thread-id") + await agent.run("Hello", session=session) - # invoked should be called with the service thread ID from response - assert mock_provider.invoked_called + # after_run should be called + assert mock_provider.after_run_called # Tests for as_tool method @@ -562,16 +550,16 @@ async def test_chat_agent_with_local_mcp_tools(client: SupportsChatGetResponse) pass -async def test_agent_tool_receives_thread_in_kwargs(chat_client_base: Any) -> None: - """Verify tool execution receives 'thread' inside **kwargs when function is called by client.""" +async def test_agent_tool_receives_session_in_kwargs(chat_client_base: Any) -> None: + """Verify tool execution receives 'session' inside **kwargs when function is called by client.""" captured: dict[str, Any] = {} - @tool(name="echo_thread_info", approval_mode="never_require") - def echo_thread_info(text: str, **kwargs: Any) -> str: # 
type: ignore[reportUnknownParameterType] - thread = kwargs.get("thread") - captured["has_thread"] = thread is not None - captured["has_message_store"] = thread.message_store is not None if isinstance(thread, AgentThread) else False + @tool(name="echo_session_info", approval_mode="never_require") + def echo_session_info(text: str, **kwargs: Any) -> str: # type: ignore[reportUnknownParameterType] + session = kwargs.get("session") + captured["has_session"] = session is not None + captured["has_state"] = session.state is not None if isinstance(session, AgentSession) else False return f"echo: {text}" # Make the base client emit a function call for our tool @@ -580,21 +568,21 @@ def echo_thread_info(text: str, **kwargs: Any) -> str: # type: ignore[reportUnk messages=Message( role="assistant", contents=[ - Content.from_function_call(call_id="1", name="echo_thread_info", arguments='{"text": "hello"}') + Content.from_function_call(call_id="1", name="echo_session_info", arguments='{"text": "hello"}') ], ) ), ChatResponse(messages=Message(role="assistant", text="done")), ] - agent = Agent(client=chat_client_base, tools=[echo_thread_info], chat_message_store_factory=ChatMessageStore) - thread = agent.get_new_thread() + agent = Agent(client=chat_client_base, tools=[echo_session_info]) + session = agent.create_session() - result = await agent.run("hello", thread=thread, options={"additional_function_arguments": {"thread": thread}}) + result = await agent.run("hello", session=session, options={"additional_function_arguments": {"session": session}}) assert result.text == "done" - assert captured.get("has_thread") is True - assert captured.get("has_message_store") is True + assert captured.get("has_session") is True + assert captured.get("has_state") is True async def test_chat_agent_tool_choice_run_level_overrides_agent_level(chat_client_base: Any, tool_tool: Any) -> None: @@ -801,73 +789,70 @@ def test_sanitize_agent_name_replaces_invalid_chars(): # endregion -# region Test 
SupportsAgentRun.get_new_thread and deserialize_thread +# region Test SupportsAgentRun.create_session @pytest.mark.asyncio -async def test_agent_get_new_thread(chat_client_base: SupportsChatGetResponse, tool_tool: FunctionTool): - """Test that get_new_thread returns a new AgentThread.""" +async def test_agent_create_session(chat_client_base: SupportsChatGetResponse, tool_tool: FunctionTool): + """Test that create_session returns a new AgentSession.""" agent = Agent(client=chat_client_base, tools=[tool_tool]) - thread = agent.get_new_thread() + session = agent.create_session() - assert thread is not None - assert isinstance(thread, AgentThread) + assert session is not None + assert isinstance(session, AgentSession) @pytest.mark.asyncio -async def test_agent_get_new_thread_with_context_provider( +async def test_agent_create_session_with_context_providers( chat_client_base: SupportsChatGetResponse, tool_tool: FunctionTool ): - """Test that get_new_thread passes context_provider to the thread.""" + """Test that create_session works when context_providers are set on the agent.""" - class TestContextProvider(ContextProvider): - async def invoking(self, messages, **kwargs): - return Context() + class TestContextProvider(BaseContextProvider): + def __init__(self): + super().__init__(source_id="test") provider = TestContextProvider() - agent = Agent(client=chat_client_base, tools=[tool_tool], context_provider=provider) + agent = Agent(client=chat_client_base, tools=[tool_tool], context_providers=[provider]) - thread = agent.get_new_thread() + session = agent.create_session() - assert thread is not None - assert thread.context_provider is provider + assert session is not None + assert agent.context_providers[0] is provider @pytest.mark.asyncio -async def test_agent_get_new_thread_with_service_thread_id( +async def test_agent_get_session_with_service_session_id( chat_client_base: SupportsChatGetResponse, tool_tool: FunctionTool ): - """Test that get_new_thread passes kwargs 
like service_thread_id to the thread.""" + """Test that get_session creates a session with service_session_id.""" agent = Agent(client=chat_client_base, tools=[tool_tool]) - thread = agent.get_new_thread(service_thread_id="test-thread-123") + session = agent.get_session("test-thread-123") - assert thread is not None - assert thread.service_thread_id == "test-thread-123" + assert session is not None + assert session.service_session_id == "test-thread-123" @pytest.mark.asyncio -async def test_agent_deserialize_thread(chat_client_base: SupportsChatGetResponse, tool_tool: FunctionTool): - """Test deserialize_thread restores a thread from serialized state.""" +async def test_agent_session_from_dict(chat_client_base: SupportsChatGetResponse, tool_tool: FunctionTool): + """Test AgentSession.from_dict restores a session from serialized state.""" agent = Agent(client=chat_client_base, tools=[tool_tool]) - # Create serialized thread state with messages + # Create serialized session state serialized_state = { - "service_thread_id": None, - "chat_message_store_state": { - "messages": [{"role": "user", "text": "Hello"}], - }, + "type": "session", + "session_id": "test-session", + "service_session_id": None, + "state": {}, } - thread = await agent.deserialize_thread(serialized_state) + session = AgentSession.from_dict(serialized_state) - assert thread is not None - assert isinstance(thread, AgentThread) - assert thread.message_store is not None - messages = await thread.message_store.list_messages() - assert len(messages) == 1 - assert messages[0].text == "Hello" + assert session is not None + assert isinstance(session, AgentSession) + assert session.session_id == "test-session" # endregion @@ -876,19 +861,6 @@ async def test_agent_deserialize_thread(chat_client_base: SupportsChatGetRespons # region Test Agent initialization edge cases -@pytest.mark.asyncio -async def test_chat_agent_raises_with_both_conversation_id_and_store(): - """Test Agent raises error with both 
conversation_id and chat_message_store_factory.""" - mock_client = MagicMock() - mock_store_factory = MagicMock() - - with pytest.raises(AgentInitializationError, match="Cannot specify both"): - Agent( - client=mock_client, - default_options={"conversation_id": "test_id"}, - chat_message_store_factory=mock_store_factory, - ) - def test_chat_agent_calls_update_agent_name_on_client(): """Test that Agent calls _update_agent_name_and_description on client if available.""" @@ -914,19 +886,22 @@ def context_tool(text: str) -> str: """A tool provided by context.""" return text - class ToolContextProvider(ContextProvider): - async def invoking(self, messages, **kwargs): - return Context(tools=[context_tool]) + class ToolContextProvider(BaseContextProvider): + def __init__(self): + super().__init__(source_id="tool-context") + + async def before_run(self, *, agent, session, context, state): + context.extend_tools("tool-context", [context_tool]) provider = ToolContextProvider() - agent = Agent(client=chat_client_base, context_provider=provider) + agent = Agent(client=chat_client_base, context_providers=[provider]) # Agent starts with empty tools list assert agent.default_options.get("tools") == [] # Run the agent and verify context tools are added - _, options, _ = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage] - thread=None, input_messages=[Message(role="user", text="Hello")] + _, options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + session=None, input_messages=[Message(role="user", text="Hello")] ) # The context tools should now be in the options @@ -940,40 +915,27 @@ async def test_chat_agent_context_provider_adds_instructions_when_agent_has_none ): """Test that context provider instructions are used when agent has no default instructions.""" - class InstructionContextProvider(ContextProvider): - async def invoking(self, messages, **kwargs): - return Context(instructions="Context-provided instructions") + 
class InstructionContextProvider(BaseContextProvider): + def __init__(self): + super().__init__(source_id="instruction-context") + + async def before_run(self, *, agent, session, context, state): + context.extend_instructions("instruction-context", "Context-provided instructions") provider = InstructionContextProvider() - agent = Agent(client=chat_client_base, context_provider=provider) + agent = Agent(client=chat_client_base, context_providers=[provider]) # Verify agent has no default instructions assert agent.default_options.get("instructions") is None # Run the agent and verify context instructions are available - _, options, _ = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage] - thread=None, input_messages=[Message(role="user", text="Hello")] + _, options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + session=None, input_messages=[Message(role="user", text="Hello")] ) # The context instructions should now be in the options assert options.get("instructions") == "Context-provided instructions" -@pytest.mark.asyncio -async def test_chat_agent_raises_on_conversation_id_mismatch(chat_client_base: SupportsChatGetResponse): - """Test that Agent raises when thread and agent have different conversation IDs.""" - agent = Agent( - client=chat_client_base, - default_options={"conversation_id": "agent-conversation-id"}, - ) - - # Create a thread with a different service_thread_id - thread = AgentThread(service_thread_id="different-thread-id") - - with pytest.raises(AgentExecutionException, match="conversation_id set on the agent is different"): - await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage] - thread=thread, input_messages=[Message(role="user", text="Hello")] - ) - # endregion diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index e135e2fee6..9e498cae76 100644 --- 
a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -168,8 +168,8 @@ def ai_func(user_query: str) -> str: agent = Agent(client=chat_client_base, tools=[ai_func]) async def handler(request: web.Request) -> web.Response: - thread = agent.get_new_thread() - result = await agent.run("Fix issue", thread=thread) + session = agent.create_session() + result = await agent.run("Fix issue", session=session) return web.Response(text=result.text or "") app = web.Application() @@ -230,8 +230,8 @@ def ai_func(user_query: str) -> str: async def init_app() -> web.Application: async def handler(request: web.Request) -> web.Response: - thread = agent.get_new_thread() - result = await agent.run("Fix issue", thread=thread) + session = agent.create_session() + result = await agent.run("Fix issue", session=session) return web.Response(text=result.text or "") app = web.Application() diff --git a/python/packages/core/tests/core/test_middleware.py b/python/packages/core/tests/core/test_middleware.py index f37c855ba3..4ac4f22f1c 100644 --- a/python/packages/core/tests/core/test_middleware.py +++ b/python/packages/core/tests/core/test_middleware.py @@ -56,17 +56,17 @@ def test_init_with_custom_values(self, mock_agent: SupportsAgentRun) -> None: assert context.stream is True assert context.metadata == metadata - def test_init_with_thread(self, mock_agent: SupportsAgentRun) -> None: - """Test AgentContext initialization with thread parameter.""" - from agent_framework import AgentThread + def test_init_with_session(self, mock_agent: SupportsAgentRun) -> None: + """Test AgentContext initialization with session parameter.""" + from agent_framework import AgentSession messages = [Message(role="user", text="test")] - thread = AgentThread() - context = AgentContext(agent=mock_agent, messages=messages, thread=thread) + session = AgentSession() + context = AgentContext(agent=mock_agent, messages=messages, 
session=session) assert context.agent is mock_agent assert context.messages == messages - assert context.thread is thread + assert context.session is session assert context.stream is False assert context.metadata == {} @@ -356,23 +356,23 @@ async def _stream() -> AsyncIterable[AgentResponseUpdate]: assert updates[1].text == "chunk2" assert execution_order == ["handler_start", "handler_end"] - async def test_execute_with_thread_in_context(self, mock_agent: SupportsAgentRun) -> None: - """Test pipeline execution properly passes thread to middleware.""" - from agent_framework import AgentThread + async def test_execute_with_session_in_context(self, mock_agent: SupportsAgentRun) -> None: + """Test pipeline execution properly passes session to middleware.""" + from agent_framework import AgentSession - captured_thread = None + captured_session = None - class ThreadCapturingMiddleware(AgentMiddleware): + class SessionCapturingMiddleware(AgentMiddleware): async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: - nonlocal captured_thread - captured_thread = context.thread + nonlocal captured_session + captured_session = context.session await call_next() - middleware = ThreadCapturingMiddleware() + middleware = SessionCapturingMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [Message(role="user", text="test")] - thread = AgentThread() - context = AgentContext(agent=mock_agent, messages=messages, thread=thread) + session = AgentSession() + context = AgentContext(agent=mock_agent, messages=messages, session=session) expected_response = AgentResponse(messages=[Message(role="assistant", text="response")]) @@ -381,22 +381,22 @@ async def final_handler(ctx: AgentContext) -> AgentResponse: result = await pipeline.execute(context, final_handler) assert result == expected_response - assert captured_thread is thread + assert captured_session is session - async def test_execute_with_no_thread_in_context(self, mock_agent: 
SupportsAgentRun) -> None: - """Test pipeline execution when no thread is provided.""" - captured_thread = "not_none" # Use string to distinguish from None + async def test_execute_with_no_session_in_context(self, mock_agent: SupportsAgentRun) -> None: + """Test pipeline execution when no session is provided.""" + captured_session = "not_none" # Use string to distinguish from None - class ThreadCapturingMiddleware(AgentMiddleware): + class SessionCapturingMiddleware(AgentMiddleware): async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: - nonlocal captured_thread - captured_thread = context.thread + nonlocal captured_session + captured_session = context.session await call_next() - middleware = ThreadCapturingMiddleware() + middleware = SessionCapturingMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [Message(role="user", text="test")] - context = AgentContext(agent=mock_agent, messages=messages, thread=None) + context = AgentContext(agent=mock_agent, messages=messages, session=None) expected_response = AgentResponse(messages=[Message(role="assistant", text="response")]) @@ -405,7 +405,7 @@ async def final_handler(ctx: AgentContext) -> AgentResponse: result = await pipeline.execute(context, final_handler) assert result == expected_response - assert captured_thread is None + assert captured_session is None class TestFunctionMiddlewarePipeline: diff --git a/python/packages/core/tests/core/test_middleware_with_agent.py b/python/packages/core/tests/core/test_middleware_with_agent.py index 597ca12dbd..bfdf0def1c 100644 --- a/python/packages/core/tests/core/test_middleware_with_agent.py +++ b/python/packages/core/tests/core/test_middleware_with_agent.py @@ -1405,19 +1405,19 @@ async def test_function_middleware(context: Any, call_next: Any) -> None: assert test_function_middleware._middleware_type == MiddlewareType.FUNCTION # type: ignore[attr-defined] -class TestChatAgentThreadBehavior: - """Test cases for 
thread behavior in AgentContext across multiple runs.""" +class TestChatAgentSessionBehavior: + """Test cases for session behavior in AgentContext across multiple runs.""" - async def test_agent_context_thread_behavior_across_multiple_runs(self, client: "MockChatClient") -> None: - """Test that AgentContext.thread property behaves correctly across multiple agent runs.""" + async def test_agent_context_session_behavior_across_multiple_runs(self, client: "MockChatClient") -> None: + """Test that AgentContext.session property behaves correctly across multiple agent runs.""" thread_states: list[dict[str, Any]] = [] - class ThreadTrackingMiddleware(AgentMiddleware): + class SessionTrackingMiddleware(AgentMiddleware): async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Capture state before next() call thread_messages = [] - if context.thread and context.thread.message_store: - thread_messages = await context.thread.message_store.list_messages() + if context.session and context.session.state.get("memory"): + thread_messages = context.session.state.get("memory", {}).get("messages", []) before_state = { "before_next": True, @@ -1432,8 +1432,8 @@ async def process(self, context: AgentContext, call_next: Callable[[], Awaitable # Capture state after next() call thread_messages_after = [] - if context.thread and context.thread.message_store: - thread_messages_after = await context.thread.message_store.list_messages() + if context.session and context.session.state.get("memory"): + thread_messages_after = context.session.state.get("memory", {}).get("messages", []) after_state = { "before_next": False, @@ -1444,19 +1444,16 @@ async def process(self, context: AgentContext, call_next: Callable[[], Awaitable } thread_states.append(after_state) - # Import the ChatMessageStore to configure the agent with a message store factory - from agent_framework import ChatMessageStore - - # Create Agent with thread tracking middleware and a message 
store factory - middleware = ThreadTrackingMiddleware() - agent = Agent(client=client, middleware=[middleware], chat_message_store_factory=ChatMessageStore) + # Create Agent with session tracking middleware + middleware = SessionTrackingMiddleware() + agent = Agent(client=client, middleware=[middleware]) - # Create a thread that will persist messages between runs - thread = agent.get_new_thread() + # Create a session that will persist messages between runs + session = agent.create_session() # First run first_messages = [Message(role="user", text="first message")] - first_response = await agent.run(first_messages, thread=thread) + first_response = await agent.run(first_messages, session=session) # Verify first response assert first_response is not None @@ -1464,7 +1461,7 @@ async def process(self, context: AgentContext, call_next: Callable[[], Awaitable # Second run - use the same thread second_messages = [Message(role="user", text="second message")] - second_response = await agent.run(second_messages, thread=thread) + second_response = await agent.run(second_messages, session=session) # Verify second response assert second_response is not None diff --git a/python/packages/core/tests/core/test_observability.py b/python/packages/core/tests/core/test_observability.py index 77b88a873e..7ffb9f3a6b 100644 --- a/python/packages/core/tests/core/test_observability.py +++ b/python/packages/core/tests/core/test_observability.py @@ -441,19 +441,19 @@ def __init__(self): self.description = "Test agent description" self.default_options: dict[str, Any] = {"model_id": "TestModel"} - def run(self, messages=None, *, thread=None, stream=False, **kwargs): + def run(self, messages=None, *, session=None, stream=False, **kwargs): if stream: return self._run_stream_impl(messages=messages, **kwargs) return self._run_impl(messages=messages, **kwargs) - async def _run_impl(self, messages=None, *, thread=None, **kwargs): + async def _run_impl(self, messages=None, *, session=None, **kwargs): 
return AgentResponse( messages=[Message("assistant", ["Agent response"])], usage_details=UsageDetails(input_token_count=15, output_token_count=25), response_id="test_response_id", ) - async def _run_stream_impl(self, messages=None, *, thread=None, **kwargs): + async def _run_stream_impl(self, messages=None, *, session=None, **kwargs): from agent_framework import AgentResponse, AgentResponseUpdate, ResponseStream async def _stream(): @@ -1572,12 +1572,12 @@ async def run( messages=None, *, stream: bool = False, - thread=None, + session=None, **kwargs, ): if stream: return ResponseStream( - self._run_stream(messages=messages, thread=thread), + self._run_stream(messages=messages, session=session), finalizer=lambda x: AgentResponse.from_updates(x), ) return AgentResponse(messages=[Message("assistant", ["Test response"])]) @@ -1586,7 +1586,7 @@ async def _run_stream( self, messages=None, *, - thread=None, + session=None, **kwargs, ): from agent_framework import AgentResponseUpdate @@ -1635,7 +1635,7 @@ def description(self): def default_options(self): return self._default_options - async def run(self, messages=None, *, stream: bool = False, thread=None, **kwargs): + async def run(self, messages=None, *, stream: bool = False, session=None, **kwargs): raise RuntimeError("Agent failed") class FailingAgent(AgentTelemetryLayer, _FailingAgent): @@ -1685,15 +1685,15 @@ def description(self): def default_options(self): return self._default_options - def run(self, messages=None, *, stream=False, thread=None, **kwargs): + def run(self, messages=None, *, stream=False, session=None, **kwargs): if stream: return self._run_stream_impl(messages=messages, **kwargs) return self._run_impl(messages=messages, **kwargs) - async def _run_impl(self, messages=None, *, thread=None, **kwargs): + async def _run_impl(self, messages=None, *, session=None, **kwargs): return AgentResponse(messages=[Message("assistant", ["Test"])]) - def _run_stream_impl(self, messages=None, *, thread=None, **kwargs): 
+ def _run_stream_impl(self, messages=None, *, session=None, **kwargs): async def _stream(): yield AgentResponseUpdate(contents=[Content.from_text("Hello ")], role="assistant") yield AgentResponseUpdate(contents=[Content.from_text("World")], role="assistant") @@ -1822,15 +1822,15 @@ def description(self): def default_options(self): return self._default_options - def run(self, messages=None, *, stream=False, thread=None, **kwargs): + def run(self, messages=None, *, stream=False, session=None, **kwargs): if stream: return self._run_stream_impl(messages=messages, **kwargs) return self._run_impl(messages=messages, **kwargs) - async def _run_impl(self, messages=None, *, thread=None, **kwargs): + async def _run_impl(self, messages=None, *, session=None, **kwargs): return AgentResponse(messages=[]) - def _run_stream_impl(self, messages=None, *, thread=None, **kwargs): + def _run_stream_impl(self, messages=None, *, session=None, **kwargs): async def _stream(): yield AgentResponseUpdate(contents=[Content.from_text("Starting")], role="assistant") raise RuntimeError("Stream failed") @@ -1919,7 +1919,7 @@ def description(self): def default_options(self): return self._default_options - async def run(self, messages=None, *, stream: bool = False, thread=None, **kwargs): + async def run(self, messages=None, *, stream: bool = False, session=None, **kwargs): if stream: return ResponseStream( self._run_stream(messages=messages, **kwargs), @@ -1927,7 +1927,7 @@ async def run(self, messages=None, *, stream: bool = False, thread=None, **kwarg ) return AgentResponse(messages=[]) - async def _run_stream(self, messages=None, *, thread=None, **kwargs): + async def _run_stream(self, messages=None, *, session=None, **kwargs): from agent_framework import AgentResponseUpdate yield AgentResponseUpdate(contents=[Content.from_text("test")], role="assistant") @@ -1974,15 +1974,15 @@ def description(self): def default_options(self): return self._default_options - def run(self, messages=None, *, 
stream=False, thread=None, **kwargs): + def run(self, messages=None, *, stream=False, session=None, **kwargs): if stream: return self._run_stream(messages=messages, **kwargs) return self._run(messages=messages, **kwargs) - async def _run(self, messages=None, *, thread=None, **kwargs): + async def _run(self, messages=None, *, session=None, **kwargs): return AgentResponse(messages=[]) - async def _run_stream(self, messages=None, *, thread=None, **kwargs): + async def _run_stream(self, messages=None, *, session=None, **kwargs): yield AgentResponseUpdate(contents=[Content.from_text("test")], role="assistant") class TestAgent(AgentTelemetryLayer, _TestAgent): diff --git a/python/packages/core/tests/openai/test_openai_assistants_client.py b/python/packages/core/tests/openai/test_openai_assistants_client.py index c50b026cb7..3e4ac27131 100644 --- a/python/packages/core/tests/openai/test_openai_assistants_client.py +++ b/python/packages/core/tests/openai/test_openai_assistants_client.py @@ -14,7 +14,7 @@ Agent, AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, ChatResponse, ChatResponseUpdate, Content, @@ -1270,25 +1270,25 @@ async def test_openai_assistants_agent_thread_persistence(): client=OpenAIAssistantsClient(model_id=INTEGRATION_TEST_MODEL), instructions="You are a helpful assistant with good memory.", ) as agent: - # Create a new thread that will be reused - thread = agent.get_new_thread() + # Create a new session that will be reused + session = agent.create_session() # First message - establish context first_response = await agent.run( - "Remember this number: 42. What number did I just tell you to remember?", thread=thread + "Remember this number: 42. 
What number did I just tell you to remember?", session=session ) assert isinstance(first_response, AgentResponse) assert "42" in first_response.text # Second message - test conversation memory second_response = await agent.run( - "What number did I tell you to remember in my previous message?", thread=thread + "What number did I tell you to remember in my previous message?", session=session ) assert isinstance(second_response, AgentResponse) assert "42" in second_response.text - # Verify thread has been populated with conversation ID - assert thread.service_thread_id is not None + # Verify session has been populated with conversation ID + assert session.service_session_id is not None @pytest.mark.flaky @@ -1303,17 +1303,17 @@ async def test_openai_assistants_agent_existing_thread_id(): instructions="You are a helpful weather agent.", tools=[get_weather], ) as agent: - # Start a conversation and get the thread ID - thread = agent.get_new_thread() - response1 = await agent.run("What's the weather in Paris?", thread=thread) + # Start a conversation and get the session ID + session = agent.create_session() + response1 = await agent.run("What's the weather in Paris?", session=session) # Validate first response assert isinstance(response1, AgentResponse) assert response1.text is not None assert any(word in response1.text.lower() for word in ["weather", "paris"]) - # The thread ID is set after the first response - existing_thread_id = thread.service_thread_id + # The session ID is set after the first response + existing_thread_id = session.service_session_id assert existing_thread_id is not None # Now continue with the same thread ID in a new agent instance @@ -1323,11 +1323,11 @@ async def test_openai_assistants_agent_existing_thread_id(): instructions="You are a helpful weather agent.", tools=[get_weather], ) as agent: - # Create a thread with the existing ID - thread = AgentThread(service_thread_id=existing_thread_id) + # Create a session with the existing ID + session 
= AgentSession(service_session_id=existing_thread_id) # Ask about the previous conversation - response2 = await agent.run("What was the last city I asked about?", thread=thread) + response2 = await agent.run("What was the last city I asked about?", session=session) # Validate that the agent remembers the previous conversation assert isinstance(response2, AgentResponse) diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index b4f431fd84..4dadbdfb11 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -7,9 +7,8 @@ AgentExecutor, AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, BaseAgent, - ChatMessageStore, Content, Message, ResponseStream, @@ -21,7 +20,7 @@ class _CountingAgent(BaseAgent): - """Agent that echoes messages with a counter to verify thread state persistence.""" + """Agent that echoes messages with a counter to verify session state persistence.""" def __init__(self, **kwargs: Any): super().__init__(**kwargs) @@ -32,7 +31,7 @@ def run( messages: str | Message | list[str] | list[Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: self.call_count += 1 @@ -52,22 +51,22 @@ async def _run() -> AgentResponse: async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: - """Test that workflow checkpoint stores AgentExecutor's cache and thread states and restores them correctly.""" + """Test that workflow checkpoint stores AgentExecutor's cache and session states and restores them correctly.""" storage = InMemoryCheckpointStorage() - # Create initial agent with a custom thread that has a message store + # Create initial agent with a custom session initial_agent = 
_CountingAgent(id="test_agent", name="TestAgent") - initial_thread = AgentThread(message_store=ChatMessageStore()) + initial_session = AgentSession() - # Add some initial messages to the thread to verify thread state persistence + # Add some initial messages to the session state to verify session state persistence initial_messages = [ Message(role="user", text="Initial message 1"), Message(role="assistant", text="Initial response 1"), ] - await initial_thread.on_new_messages(initial_messages) + initial_session.state["history"] = {"messages": initial_messages} - # Create AgentExecutor with the thread - executor = AgentExecutor(initial_agent, agent_thread=initial_thread) + # Create AgentExecutor with the session + executor = AgentExecutor(initial_agent, session=initial_session) # Build workflow with checkpointing enabled wf = SequentialBuilder(participants=[executor], checkpoint_storage=storage).build() @@ -90,12 +89,7 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: "and the second one is after the agent execution." 
)
 
-    # Get the second checkpoint which should contain the state after processing
-    # the first message by the start executor in the sequential workflow
-    checkpoints.sort(key=lambda cp: cp.timestamp)
-    restore_checkpoint = checkpoints[1]
-
-    # Verify checkpoint contains executor state with both cache and thread
+    restore_checkpoint = sorted(checkpoints, key=lambda cp: cp.timestamp)[1]  # checkpoint taken after the agent ran
     assert "_executor_state" in restore_checkpoint.state
     executor_states = restore_checkpoint.state["_executor_state"]
     assert isinstance(executor_states, dict)
@@ -103,13 +97,12 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
     executor_state = executor_states[executor.id]  # type: ignore[index]
     assert "cache" in executor_state, "Checkpoint should store executor cache state"
-    assert "agent_thread" in executor_state, "Checkpoint should store executor thread state"
+    assert "agent_session" in executor_state, "Checkpoint should store executor session state"
 
-    # Verify thread state includes message store
-    thread_state = executor_state["agent_thread"]  # type: ignore[index]
-    assert "chat_message_store_state" in thread_state, "Thread state should include message store"
-    chat_store_state = thread_state["chat_message_store_state"]  # type: ignore[index]
-    assert "messages" in chat_store_state, "Message store state should include messages"
+    # Verify session state structure
+    session_state = executor_state["agent_session"]  # type: ignore[index]
+    assert "session_id" in session_state, "Session state should include session_id"
+    assert "state" in session_state, "Session state should include state dict"
 
     # Verify checkpoint contains pending requests from agents and responses to be sent
     assert "pending_agent_requests" in executor_state
@@ -118,8 +111,8 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
     # Create a new agent and executor for restoration
     # This simulates starting from a fresh state and restoring from checkpoint
    restored_agent = 
_CountingAgent(id="test_agent", name="TestAgent") - restored_thread = AgentThread(message_store=ChatMessageStore()) - restored_executor = AgentExecutor(restored_agent, agent_thread=restored_thread) + restored_session = AgentSession() + restored_executor = AgentExecutor(restored_agent, session=restored_session) # Verify the restored agent starts with a fresh state assert restored_agent.call_count == 0 @@ -140,39 +133,27 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: assert resumed_output is not None - # Verify the restored executor's state matches the original - # The cache should be restored (though it may be cleared after processing) - # The thread should have all messages including those from the initial state - message_store = restored_executor._agent_thread.message_store # type: ignore[reportPrivateUsage] - assert message_store is not None - thread_messages = await message_store.list_messages() - - # Thread should contain: - # 1. Initial messages from before the checkpoint (2 messages) - # 2. User message from first run (1 message) - # 3. 
Assistant response from first run (1 message) - assert len(thread_messages) >= 2, "Thread should preserve initial messages from before checkpoint" - - # Verify initial messages are preserved - assert thread_messages[0].text == "Initial message 1" - assert thread_messages[1].text == "Initial response 1" + # Verify the restored executor's session state was restored + restored_session_obj = restored_executor._session # type: ignore[reportPrivateUsage] + assert restored_session_obj is not None + assert restored_session_obj.session_id == initial_session.session_id async def test_agent_executor_save_and_restore_state_directly() -> None: """Test AgentExecutor's on_checkpoint_save and on_checkpoint_restore methods directly.""" - # Create agent with thread containing messages + # Create agent with session containing state agent = _CountingAgent(id="direct_test_agent", name="DirectTestAgent") - thread = AgentThread(message_store=ChatMessageStore()) + session = AgentSession() - # Add messages to thread - thread_messages = [ - Message(role="user", text="Message in thread 1"), - Message(role="assistant", text="Thread response 1"), - Message(role="user", text="Message in thread 2"), + # Add messages to session state + session_messages = [ + Message(role="user", text="Message in session 1"), + Message(role="assistant", text="Session response 1"), + Message(role="user", text="Message in session 2"), ] - await thread.on_new_messages(thread_messages) + session.state["history"] = {"messages": session_messages} - executor = AgentExecutor(agent, agent_thread=thread) + executor = AgentExecutor(agent, session=session) # Add messages to executor cache cache_messages = [ @@ -184,26 +165,23 @@ async def test_agent_executor_save_and_restore_state_directly() -> None: # Snapshot the state state = await executor.on_checkpoint_save() - # Verify snapshot contains both cache and thread + # Verify snapshot contains both cache and session assert "cache" in state - assert "agent_thread" in state + 
assert "agent_session" in state - # Verify thread state structure - thread_state = state["agent_thread"] # type: ignore[index] - assert "chat_message_store_state" in thread_state - assert "messages" in thread_state["chat_message_store_state"] + # Verify session state structure + session_state = state["agent_session"] # type: ignore[index] + assert "session_id" in session_state + assert "state" in session_state # Create new executor to restore into new_agent = _CountingAgent(id="direct_test_agent", name="DirectTestAgent") - new_thread = AgentThread(message_store=ChatMessageStore()) - new_executor = AgentExecutor(new_agent, agent_thread=new_thread) + new_session = AgentSession() + new_executor = AgentExecutor(new_agent, session=new_session) # Verify new executor starts empty assert len(new_executor._cache) == 0 # type: ignore[reportPrivateUsage] - initial_message_store = new_thread.message_store - assert initial_message_store is not None - initial_thread_msgs = await initial_message_store.list_messages() - assert len(initial_thread_msgs) == 0 + assert len(new_session.state) == 0 # Restore state await new_executor.on_checkpoint_restore(state) @@ -214,11 +192,6 @@ async def test_agent_executor_save_and_restore_state_directly() -> None: assert restored_cache[0].text == "Cached user message" assert restored_cache[1].text == "Cached assistant response" - # Verify thread messages are restored - restored_message_store = new_executor._agent_thread.message_store # type: ignore[reportPrivateUsage] - assert restored_message_store is not None - restored_thread_msgs = await restored_message_store.list_messages() - assert len(restored_thread_msgs) == len(thread_messages) - assert restored_thread_msgs[0].text == "Message in thread 1" - assert restored_thread_msgs[1].text == "Thread response 1" - assert restored_thread_msgs[2].text == "Message in thread 2" + # Verify session was restored with correct session_id + restored_session = new_executor._session # type: 
ignore[reportPrivateUsage] + assert restored_session.session_id == session.session_id diff --git a/python/packages/core/tests/workflow/test_agent_executor_tool_calls.py b/python/packages/core/tests/workflow/test_agent_executor_tool_calls.py index 3bb51d2224..47356638a6 100644 --- a/python/packages/core/tests/workflow/test_agent_executor_tool_calls.py +++ b/python/packages/core/tests/workflow/test_agent_executor_tool_calls.py @@ -13,7 +13,7 @@ AgentExecutorResponse, AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, BaseAgent, ChatResponse, ChatResponseUpdate, @@ -42,7 +42,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: if stream: diff --git a/python/packages/core/tests/workflow/test_agent_utils.py b/python/packages/core/tests/workflow/test_agent_utils.py index 8a8beae5b1..d3889b4d3b 100644 --- a/python/packages/core/tests/workflow/test_agent_utils.py +++ b/python/packages/core/tests/workflow/test_agent_utils.py @@ -3,7 +3,7 @@ from collections.abc import AsyncIterable from typing import Any -from agent_framework import AgentResponse, AgentResponseUpdate, AgentThread, Message +from agent_framework import AgentResponse, AgentResponseUpdate, AgentSession, Message from agent_framework._workflows._agent_utils import resolve_agent_id @@ -37,12 +37,12 @@ def run( messages: str | Message | list[str] | list[Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> AgentResponse | AsyncIterable[AgentResponseUpdate]: ... 
- def get_new_thread(self, **kwargs: Any) -> AgentThread: - """Creates a new conversation thread for the agent.""" + def create_session(self, **kwargs: Any) -> AgentSession: + """Creates a new conversation session for the agent.""" ... diff --git a/python/packages/core/tests/workflow/test_full_conversation.py b/python/packages/core/tests/workflow/test_full_conversation.py index 41003f6544..80a10347b6 100644 --- a/python/packages/core/tests/workflow/test_full_conversation.py +++ b/python/packages/core/tests/workflow/test_full_conversation.py @@ -12,7 +12,7 @@ AgentExecutorResponse, AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, BaseAgent, Content, Executor, @@ -38,7 +38,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: if stream: @@ -108,7 +108,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: # Normalize and record messages for verification diff --git a/python/packages/core/tests/workflow/test_workflow.py b/python/packages/core/tests/workflow/test_workflow.py index 744ad827ea..fb90df6b39 100644 --- a/python/packages/core/tests/workflow/test_workflow.py +++ b/python/packages/core/tests/workflow/test_workflow.py @@ -13,7 +13,7 @@ AgentExecutor, AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, BaseAgent, Content, Executor, @@ -838,7 +838,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse] | 
ResponseStream[AgentResponseUpdate, AgentResponse]: if stream: diff --git a/python/packages/core/tests/workflow/test_workflow_agent.py b/python/packages/core/tests/workflow/test_workflow_agent.py index 2a1532502b..8c1066aa42 100644 --- a/python/packages/core/tests/workflow/test_workflow_agent.py +++ b/python/packages/core/tests/workflow/test_workflow_agent.py @@ -11,8 +11,7 @@ AgentExecutorRequest, AgentResponse, AgentResponseUpdate, - AgentThread, - ChatMessageStore, + AgentSession, Content, Executor, Message, @@ -512,79 +511,52 @@ async def list_yielding_executor(messages: list[Message], ctx: WorkflowContext[N assert texts == ["first message", "second message", "third fourth"] async def test_thread_conversation_history_included_in_workflow_run(self) -> None: - """Test that conversation history from thread is included when running WorkflowAgent. - - This verifies that when a thread with existing messages is provided to agent.run(), - the workflow receives the complete conversation history (thread history + new messages). 
- """ + """Test that messages provided to agent.run() are passed through to the workflow.""" # Create an executor that captures all received messages capturing_executor = ConversationHistoryCapturingExecutor(id="capturing", streaming=False) workflow = WorkflowBuilder(start_executor=capturing_executor).build() agent = WorkflowAgent(workflow=workflow, name="Thread History Test Agent") - # Create a thread with existing conversation history - history_messages = [ - Message(role="user", text="Previous user message"), - Message(role="assistant", text="Previous assistant response"), - ] - message_store = ChatMessageStore(messages=history_messages) - thread = AgentThread(message_store=message_store) + # Create a session + session = AgentSession() - # Run the agent with the thread and a new message + # Run the agent with the session and a new message new_message = "New user question" - await agent.run(new_message, thread=thread) + await agent.run(new_message, session=session) - # Verify the executor received both history AND new message - assert len(capturing_executor.received_messages) == 3 - - # Verify the order: history first, then new message - assert capturing_executor.received_messages[0].text == "Previous user message" - assert capturing_executor.received_messages[1].text == "Previous assistant response" - assert capturing_executor.received_messages[2].text == "New user question" + # Verify the executor received the message + assert len(capturing_executor.received_messages) == 1 + assert capturing_executor.received_messages[0].text == "New user question" async def test_thread_conversation_history_included_in_workflow_stream(self) -> None: - """Test that conversation history from thread is included when streaming WorkflowAgent. - - This verifies that stream=True also includes thread history. 
- """ + """Test that messages provided to agent.run() are passed through when streaming WorkflowAgent.""" # Create an executor that captures all received messages capturing_executor = ConversationHistoryCapturingExecutor(id="capturing_stream") workflow = WorkflowBuilder(start_executor=capturing_executor).build() agent = WorkflowAgent(workflow=workflow, name="Thread Stream Test Agent") - # Create a thread with existing conversation history - history_messages = [ - Message(role="system", text="You are a helpful assistant"), - Message(role="user", text="Hello"), - Message("assistant", ["Hi there!"]), - ] - message_store = ChatMessageStore(messages=history_messages) - thread = AgentThread(message_store=message_store) + # Create a session + session = AgentSession() - # Stream from the agent with the thread and a new message - async for _ in agent.run("How are you?", stream=True, thread=thread): + # Stream from the agent with the session and a new message + async for _ in agent.run("How are you?", stream=True, session=session): pass - # Verify the executor received all messages (3 from history + 1 new) - assert len(capturing_executor.received_messages) == 4 - - # Verify the order - assert capturing_executor.received_messages[0].text == "You are a helpful assistant" - assert capturing_executor.received_messages[1].text == "Hello" - assert capturing_executor.received_messages[2].text == "Hi there!" - assert capturing_executor.received_messages[3].text == "How are you?" + # Verify the executor received the message + assert len(capturing_executor.received_messages) == 1 + assert capturing_executor.received_messages[0].text == "How are you?" 
async def test_empty_thread_works_correctly(self) -> None: - """Test that an empty thread (no message store) works correctly.""" + """Test that an empty session (no message store) works correctly.""" capturing_executor = ConversationHistoryCapturingExecutor(id="empty_thread_test") workflow = WorkflowBuilder(start_executor=capturing_executor).build() agent = WorkflowAgent(workflow=workflow, name="Empty Thread Test Agent") - # Create an empty thread - thread = AgentThread() + # Create an empty session + session = AgentSession() - # Run with the empty thread - await agent.run("Just a new message", thread=thread) + # Run with the empty session + await agent.run("Just a new message", session=session) # Should only receive the new message assert len(capturing_executor.received_messages) == 1 @@ -622,27 +594,27 @@ def __init__(self, name: str, response_text: str) -> None: self.description: str | None = None self._response_text = response_text - def get_new_thread(self, **kwargs: Any) -> AgentThread: - return AgentThread() + def create_session(self, **kwargs: Any) -> AgentSession: + return AgentSession() def run( self, messages: str | Content | Message | Sequence[str | Content | Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: if stream: - return self._run_stream(messages=messages, thread=thread, **kwargs) - return self._run(messages=messages, thread=thread, **kwargs) + return self._run_stream(messages=messages, session=session, **kwargs) + return self._run(messages=messages, session=session, **kwargs) async def _run( self, messages: str | Content | Message | Sequence[str | Content | Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> AgentResponse: @@ -654,7 +626,7 @@ def _run_stream( self, messages: str | 
Content | Message | Sequence[str | Content | Message] | None = None, *, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: async def _iter(): @@ -710,27 +682,27 @@ def __init__(self, name: str, response_text: str) -> None: self.description: str | None = None self._response_text = response_text - def get_new_thread(self, **kwargs: Any) -> AgentThread: - return AgentThread() + def create_session(self, **kwargs: Any) -> AgentSession: + return AgentSession() def run( self, messages: str | Content | Message | Sequence[str | Content | Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: if stream: - return self._run_stream(messages=messages, thread=thread, **kwargs) - return self._run(messages=messages, thread=thread, **kwargs) + return self._run_stream(messages=messages, session=session, **kwargs) + return self._run(messages=messages, session=session, **kwargs) async def _run( self, messages: str | Content | Message | Sequence[str | Content | Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> AgentResponse: @@ -742,7 +714,7 @@ def _run_stream( self, messages: str | Content | Message | Sequence[str | Content | Message] | None = None, *, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: async def _iter(): @@ -1037,7 +1009,7 @@ def test_merge_updates_metadata_aggregation(self): def test_merge_updates_function_result_ordering_github_2977(self): """Test that FunctionResultContent updates are placed after their FunctionCallContent. 
- This test reproduces GitHub issue #2977: When using a thread with WorkflowAgent, + This test reproduces GitHub issue #2977: When using a session with WorkflowAgent, FunctionResultContent updates without response_id were being added to global_dangling and placed at the end of messages. This caused OpenAI to reject the conversation because "An assistant message with 'tool_calls' must be followed by tool messages responding diff --git a/python/packages/core/tests/workflow/test_workflow_builder.py b/python/packages/core/tests/workflow/test_workflow_builder.py index fd0d74586a..073a24e5a3 100644 --- a/python/packages/core/tests/workflow/test_workflow_builder.py +++ b/python/packages/core/tests/workflow/test_workflow_builder.py @@ -9,7 +9,7 @@ AgentExecutor, AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, BaseAgent, Executor, Message, @@ -21,7 +21,7 @@ class DummyAgent(BaseAgent): - def run(self, messages=None, *, stream: bool = False, thread: AgentThread | None = None, **kwargs): # type: ignore[override] + def run(self, messages=None, *, stream: bool = False, session: AgentSession | None = None, **kwargs): # type: ignore[override] if stream: return self._run_stream_impl() return self._run_impl(messages) diff --git a/python/packages/core/tests/workflow/test_workflow_kwargs.py b/python/packages/core/tests/workflow/test_workflow_kwargs.py index 9c21652281..5a71afafe4 100644 --- a/python/packages/core/tests/workflow/test_workflow_kwargs.py +++ b/python/packages/core/tests/workflow/test_workflow_kwargs.py @@ -8,7 +8,7 @@ from agent_framework import ( AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, BaseAgent, Content, Message, @@ -55,7 +55,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: 
self.captured_kwargs.append(dict(kwargs)) diff --git a/python/packages/devui/tests/devui/conftest.py b/python/packages/devui/tests/devui/conftest.py index 14f48617db..3ff5f499a7 100644 --- a/python/packages/devui/tests/devui/conftest.py +++ b/python/packages/devui/tests/devui/conftest.py @@ -20,7 +20,7 @@ Agent, AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, BaseAgent, BaseChatClient, ChatResponse, @@ -162,19 +162,19 @@ def run( messages: str | Message | list[str] | list[Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: self.call_count += 1 if stream: - return self._run_stream(messages=messages, thread=thread, **kwargs) - return self._run(messages=messages, thread=thread, **kwargs) + return self._run_stream(messages=messages, session=session, **kwargs) + return self._run(messages=messages, session=session, **kwargs) async def _run( self, messages: str | Message | list[str] | list[Message] | None = None, *, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> AgentResponse: self.call_count += 1 @@ -184,7 +184,7 @@ def _run_stream( self, messages: str | Message | list[str] | list[Message] | None = None, *, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: self.call_count += 1 @@ -208,19 +208,19 @@ def run( messages: str | Message | list[str] | list[Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: self.call_count += 1 if stream: - return self._run_stream(messages=messages, thread=thread, **kwargs) - return self._run(messages=messages, thread=thread, **kwargs) + return 
self._run_stream(messages=messages, session=session, **kwargs) + return self._run(messages=messages, session=session, **kwargs) async def _run( self, messages: str | Message | list[str] | list[Message] | None = None, *, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> AgentResponse: return AgentResponse(messages=[Message("assistant", ["done"])]) @@ -229,7 +229,7 @@ def _run_stream( self, messages: str | Message | list[str] | list[Message] | None = None, *, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: async def _iter() -> AsyncIterable[AgentResponseUpdate]: diff --git a/python/packages/devui/tests/devui/test_discovery.py b/python/packages/devui/tests/devui/test_discovery.py index d1f68c302f..4a4efaadab 100644 --- a/python/packages/devui/tests/devui/test_discovery.py +++ b/python/packages/devui/tests/devui/test_discovery.py @@ -74,14 +74,14 @@ async def test_discovery_accepts_agents_with_only_run(): init_file = agent_dir / "__init__.py" init_file.write_text(""" -from agent_framework import AgentResponse, AgentThread, Message, Role, Content +from agent_framework import AgentResponse, AgentSession, Message, Role, Content class NonStreamingAgent: id = "non_streaming" name = "Non-Streaming Agent" description = "Agent with run() method" - async def run(self, messages=None, *, thread=None, **kwargs): + async def run(self, messages=None, *, session=None, **kwargs): return AgentResponse( messages=[Message( role="assistant", @@ -90,8 +90,8 @@ async def run(self, messages=None, *, thread=None, **kwargs): response_id="test" ) - def get_new_thread(self, **kwargs): - return AgentThread() + def create_session(self, **kwargs): + return AgentSession() agent = NonStreamingAgent() """) @@ -188,19 +188,19 @@ def test_func(input: str) -> str: agent_dir = temp_path / "my_agent" agent_dir.mkdir() (agent_dir / "agent.py").write_text(""" -from 
agent_framework import AgentResponse, AgentThread, Message, Role, TextContent +from agent_framework import AgentResponse, AgentSession, Message, Role, TextContent class TestAgent: name = "Test Agent" - async def run(self, messages=None, *, thread=None, **kwargs): + async def run(self, messages=None, *, session=None, **kwargs): return AgentResponse( messages=[Message(role="assistant", contents=[Content.from_text(text="test")])], response_id="test" ) - def get_new_thread(self, **kwargs): - return AgentThread() + def create_session(self, **kwargs): + return AgentSession() agent = TestAgent() """) @@ -320,7 +320,7 @@ class WeatherAgent: name = "Weather Agent" description = "Gets weather information" - def run(self, input_str, *, stream: bool = False, thread=None, **kwargs): + def run(self, input_str, *, stream: bool = False, session=None, **kwargs): return f"Weather in {input_str}" """) diff --git a/python/packages/devui/tests/devui/test_execution.py b/python/packages/devui/tests/devui/test_execution.py index 3fff11ad79..a7ac622c75 100644 --- a/python/packages/devui/tests/devui/test_execution.py +++ b/python/packages/devui/tests/devui/test_execution.py @@ -538,7 +538,7 @@ def test_extract_workflow_hil_responses_handles_stringified_json(): async def test_executor_handles_streaming_agent(): """Test executor handles agents with run(stream=True) method.""" - from agent_framework import AgentResponse, AgentResponseUpdate, AgentThread, Content, Message + from agent_framework import AgentResponse, AgentResponseUpdate, AgentSession, Content, Message class StreamingAgent: """Agent with run() method supporting stream parameter.""" @@ -547,7 +547,7 @@ class StreamingAgent: name = "Streaming Test Agent" description = "Test agent with run(stream=True)" - def run(self, messages=None, *, stream=False, thread=None, **kwargs): + def run(self, messages=None, *, stream=False, session=None, **kwargs): if stream: # Return an async generator for streaming return self._stream_impl(messages) 
@@ -566,8 +566,8 @@ async def _stream_impl(self, messages): role="assistant", ) - def get_new_thread(self, **kwargs): - return AgentThread() + def create_session(self, **kwargs): + return AgentSession() # Create executor and register agent discovery = EntityDiscovery(None) @@ -754,7 +754,7 @@ class StreamingAgent: name = "Streaming Test Agent" description = "Test agent for streaming" - async def run(self, input_str, *, stream: bool = False, thread=None, **kwargs): + async def run(self, input_str, *, stream: bool = False, session=None, **kwargs): if stream: async def _stream(): for i, word in enumerate(f"Processing {input_str}".split()): diff --git a/python/packages/github_copilot/tests/test_github_copilot_agent.py b/python/packages/github_copilot/tests/test_github_copilot_agent.py index 0ce0ca0307..c9bca2d0a6 100644 --- a/python/packages/github_copilot/tests/test_github_copilot_agent.py +++ b/python/packages/github_copilot/tests/test_github_copilot_agent.py @@ -10,7 +10,7 @@ from agent_framework import ( AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, Content, Message, ) @@ -300,21 +300,21 @@ async def test_run_chat_message( assert isinstance(response, AgentResponse) assert len(response.messages) == 1 - async def test_run_with_thread( + async def test_run_with_session( self, mock_client: MagicMock, mock_session: MagicMock, assistant_message_event: SessionEvent, ) -> None: - """Test run method with existing thread.""" + """Test run method with existing session.""" mock_session.send_and_wait.return_value = assistant_message_event agent = GitHubCopilotAgent(client=mock_client) - thread = AgentThread() - response = await agent.run("Hello", thread=thread) + session = AgentSession() + response = await agent.run("Hello", session=session) assert isinstance(response, AgentResponse) - assert thread.service_thread_id == mock_session.session_id + assert session.service_session_id == mock_session.session_id async def test_run_with_runtime_options( self, @@ 
-392,13 +392,13 @@ def mock_on(handler: Any) -> Any: assert responses[0].role == "assistant" assert responses[0].contents[0].text == "Hello" - async def test_run_streaming_with_thread( + async def test_run_streaming_with_session( self, mock_client: MagicMock, mock_session: MagicMock, session_idle_event: SessionEvent, ) -> None: - """Test streaming with existing thread.""" + """Test streaming with existing session.""" def mock_on(handler: Any) -> Any: handler(session_idle_event) @@ -407,12 +407,12 @@ def mock_on(handler: Any) -> Any: mock_session.on = mock_on agent = GitHubCopilotAgent(client=mock_client) - thread = AgentThread() + session = AgentSession() - async for _ in agent.run("Hello", thread=thread, stream=True): + async for _ in agent.run("Hello", session=session, stream=True): pass - assert thread.service_thread_id == mock_session.session_id + assert session.service_session_id == mock_session.session_id async def test_run_streaming_error( self, @@ -461,20 +461,20 @@ def mock_on(handler: Any) -> Any: class TestGitHubCopilotAgentSessionManagement: """Test cases for session management.""" - async def test_session_resumed_for_same_thread( + async def test_session_resumed_for_same_session( self, mock_client: MagicMock, mock_session: MagicMock, assistant_message_event: SessionEvent, ) -> None: - """Test that subsequent calls on the same thread resume the session.""" + """Test that subsequent calls on the same session resume the session.""" mock_session.send_and_wait.return_value = assistant_message_event agent = GitHubCopilotAgent(client=mock_client) - thread = AgentThread() + session = AgentSession() - await agent.run("Hello", thread=thread) - await agent.run("World", thread=thread) + await agent.run("Hello", session=session) + await agent.run("World", session=session) mock_client.create_session.assert_called_once() mock_client.resume_session.assert_called_once_with(mock_session.session_id, unittest.mock.ANY) @@ -490,7 +490,7 @@ async def 
test_session_config_includes_model( ) await agent.start() - await agent._get_or_create_session(AgentThread()) # type: ignore + await agent._get_or_create_session(AgentSession()) # type: ignore call_args = mock_client.create_session.call_args config = call_args[0][0] @@ -508,7 +508,7 @@ async def test_session_config_includes_instructions( ) await agent.start() - await agent._get_or_create_session(AgentThread()) # type: ignore + await agent._get_or_create_session(AgentSession()) # type: ignore call_args = mock_client.create_session.call_args config = call_args[0][0] @@ -531,7 +531,7 @@ async def test_runtime_options_take_precedence_over_default( "system_message": {"mode": "replace", "content": "Runtime instructions"} } await agent._get_or_create_session( # type: ignore - AgentThread(), + AgentSession(), runtime_options=runtime_options, ) @@ -549,25 +549,25 @@ async def test_session_config_includes_streaming_flag( agent = GitHubCopilotAgent(client=mock_client) await agent.start() - await agent._get_or_create_session(AgentThread(), streaming=True) # type: ignore + await agent._get_or_create_session(AgentSession(), streaming=True) # type: ignore call_args = mock_client.create_session.call_args config = call_args[0][0] assert config["streaming"] is True - async def test_resume_session_with_existing_service_thread_id( + async def test_resume_session_with_existing_service_session_id( self, mock_client: MagicMock, mock_session: MagicMock, ) -> None: - """Test that session is resumed when thread has a service_thread_id.""" + """Test that session is resumed when session has a service_session_id.""" agent = GitHubCopilotAgent(client=mock_client) await agent.start() - thread = AgentThread() - thread.service_thread_id = "existing-session-id" + session = AgentSession() + session.service_session_id = "existing-session-id" - await agent._get_or_create_session(thread) # type: ignore + await agent._get_or_create_session(session) # type: ignore 
mock_client.create_session.assert_not_called() mock_client.resume_session.assert_called_once() @@ -596,10 +596,10 @@ def my_tool(arg: str) -> str: ) await agent.start() - thread = AgentThread() - thread.service_thread_id = "existing-session-id" + session = AgentSession() + session.service_session_id = "existing-session-id" - await agent._get_or_create_session(thread) # type: ignore + await agent._get_or_create_session(session) # type: ignore mock_client.resume_session.assert_called_once() call_args = mock_client.resume_session.call_args @@ -639,7 +639,7 @@ async def test_mcp_servers_passed_to_create_session( ) await agent.start() - await agent._get_or_create_session(AgentThread()) # type: ignore + await agent._get_or_create_session(AgentSession()) # type: ignore call_args = mock_client.create_session.call_args config = call_args[0][0] @@ -672,10 +672,10 @@ async def test_mcp_servers_passed_to_resume_session( ) await agent.start() - thread = AgentThread() - thread.service_thread_id = "existing-session-id" + session = AgentSession() + session.service_session_id = "existing-session-id" - await agent._get_or_create_session(thread) # type: ignore + await agent._get_or_create_session(session) # type: ignore mock_client.resume_session.assert_called_once() call_args = mock_client.resume_session.call_args @@ -692,7 +692,7 @@ async def test_session_config_excludes_mcp_servers_when_not_set( agent = GitHubCopilotAgent(client=mock_client) await agent.start() - await agent._get_or_create_session(AgentThread()) # type: ignore + await agent._get_or_create_session(AgentSession()) # type: ignore call_args = mock_client.create_session.call_args config = call_args[0][0] @@ -716,7 +716,7 @@ def my_tool(arg: str) -> str: agent = GitHubCopilotAgent(client=mock_client, tools=[my_tool]) await agent.start() - await agent._get_or_create_session(AgentThread()) # type: ignore + await agent._get_or_create_session(AgentSession()) # type: ignore call_args = mock_client.create_session.call_args 
config = call_args[0][0] @@ -739,7 +739,7 @@ def my_tool(arg: str) -> str: agent = GitHubCopilotAgent(client=mock_client, tools=[my_tool]) await agent.start() - await agent._get_or_create_session(AgentThread()) # type: ignore + await agent._get_or_create_session(AgentSession()) # type: ignore call_args = mock_client.create_session.call_args config = call_args[0][0] @@ -764,7 +764,7 @@ def failing_tool(arg: str) -> str: agent = GitHubCopilotAgent(client=mock_client, tools=[failing_tool]) await agent.start() - await agent._get_or_create_session(AgentThread()) # type: ignore + await agent._get_or_create_session(AgentSession()) # type: ignore call_args = mock_client.create_session.call_args config = call_args[0][0] @@ -867,7 +867,7 @@ async def test_get_or_create_session_raises_on_create_error( await agent.start() with pytest.raises(ServiceException, match="Failed to create GitHub Copilot session"): - await agent._get_or_create_session(AgentThread()) # type: ignore + await agent._get_or_create_session(AgentSession()) # type: ignore async def test_get_or_create_session_raises_when_client_not_initialized(self) -> None: """Test that _get_or_create_session raises ServiceException when client is not initialized.""" @@ -875,7 +875,7 @@ async def test_get_or_create_session_raises_when_client_not_initialized(self) -> # Don't call start() - client remains None with pytest.raises(ServiceException, match="GitHub Copilot client not initialized"): - await agent._get_or_create_session(AgentThread()) # type: ignore + await agent._get_or_create_session(AgentSession()) # type: ignore class TestGitHubCopilotAgentPermissions: @@ -919,7 +919,7 @@ def approve_shell_read(request: PermissionRequest, context: dict[str, str]) -> P ) await agent.start() - await agent._get_or_create_session(AgentThread()) # type: ignore + await agent._get_or_create_session(AgentSession()) # type: ignore call_args = mock_client.create_session.call_args config = call_args[0][0] @@ -935,7 +935,7 @@ async def 
test_session_config_excludes_permission_handler_when_not_set( agent = GitHubCopilotAgent(client=mock_client) await agent.start() - await agent._get_or_create_session(AgentThread()) # type: ignore + await agent._get_or_create_session(AgentSession()) # type: ignore call_args = mock_client.create_session.call_args config = call_args[0][0] diff --git a/python/packages/mem0/tests/test_mem0_context_provider.py b/python/packages/mem0/tests/test_mem0_context_provider.py index f01d12053a..129f1bfa61 100644 --- a/python/packages/mem0/tests/test_mem0_context_provider.py +++ b/python/packages/mem0/tests/test_mem0_context_provider.py @@ -4,17 +4,19 @@ import importlib import os import sys +from typing import Any from unittest.mock import AsyncMock import pytest -from agent_framework import Content, Context, Message +from agent_framework import AgentResponse, Message +from agent_framework._sessions import AgentSession, SessionContext from agent_framework.exceptions import ServiceInitializationError -from agent_framework.mem0 import Mem0Provider +from agent_framework.mem0 import Mem0ContextProvider -def test_mem0_provider_import() -> None: - """Test that Mem0Provider can be imported.""" - assert Mem0Provider is not None +def test_mem0_context_provider_import() -> None: + """Test that Mem0ContextProvider can be imported.""" + assert Mem0ContextProvider is not None @pytest.fixture @@ -32,6 +34,18 @@ def mock_mem0_client() -> AsyncMock: return mock_client +@pytest.fixture +def mock_agent() -> AsyncMock: + """Create a mock agent.""" + return AsyncMock() + + +@pytest.fixture +def session() -> AgentSession: + """Create a test AgentSession.""" + return AgentSession(session_id="test-session") + + @pytest.fixture def sample_messages() -> list[Message]: """Create sample chat messages for testing.""" @@ -42,63 +56,63 @@ def sample_messages() -> list[Message]: ] +def _make_context(input_messages: list[Message], session_id: str = "test-session") -> SessionContext: + """Helper to create a 
SessionContext with the given input messages.""" + return SessionContext(session_id=session_id, input_messages=input_messages) + + +def _empty_state() -> dict[str, Any]: + """Helper to create an empty state dict.""" + return {} + + def test_init_with_all_ids(mock_mem0_client: AsyncMock) -> None: """Test initialization with all IDs provided.""" - provider = Mem0Provider( + provider = Mem0ContextProvider( + source_id="mem0", user_id="user123", agent_id="agent123", application_id="app123", - thread_id="thread123", mem0_client=mock_mem0_client, ) assert provider.user_id == "user123" assert provider.agent_id == "agent123" assert provider.application_id == "app123" - assert provider.thread_id == "thread123" def test_init_without_filters_succeeds(mock_mem0_client: AsyncMock) -> None: """Test that initialization succeeds even without filters (validation happens during invocation).""" - provider = Mem0Provider(mem0_client=mock_mem0_client) + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) assert provider.user_id is None assert provider.agent_id is None assert provider.application_id is None - assert provider.thread_id is None def test_init_with_custom_context_prompt(mock_mem0_client: AsyncMock) -> None: """Test initialization with custom context prompt.""" custom_prompt = "## Custom Memories\nConsider these memories:" - provider = Mem0Provider(user_id="user123", context_prompt=custom_prompt, mem0_client=mock_mem0_client) - assert provider.context_prompt == custom_prompt - - -def test_init_with_scope_to_per_operation_thread_id(mock_mem0_client: AsyncMock) -> None: - """Test initialization with scope_to_per_operation_thread_id enabled.""" - provider = Mem0Provider( - user_id="user123", - scope_to_per_operation_thread_id=True, - mem0_client=mock_mem0_client, + provider = Mem0ContextProvider( + source_id="mem0", user_id="user123", context_prompt=custom_prompt, mem0_client=mock_mem0_client ) - assert provider.scope_to_per_operation_thread_id is 
True + assert provider.context_prompt == custom_prompt def test_init_with_provided_client_should_not_close(mock_mem0_client: AsyncMock) -> None: """Test that provided client should not be closed by provider.""" - provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client) + provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) assert provider._should_close_client is False async def test_async_context_manager_entry(mock_mem0_client: AsyncMock) -> None: """Test async context manager entry returns self.""" - provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client) + provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) async with provider as ctx: assert ctx is provider async def test_async_context_manager_exit_does_not_close_provided_client(mock_mem0_client: AsyncMock) -> None: """Test that async context manager does not close provided client.""" - provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client) + provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) assert provider._should_close_client is False async with provider: @@ -107,82 +121,47 @@ async def test_async_context_manager_exit_does_not_close_provided_client(mock_me mock_mem0_client.__aexit__.assert_not_called() -class TestMem0ProviderThreadMethods: - """Test thread lifecycle methods.""" - - async def test_thread_created_sets_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None: - """Test that thread_created sets per-operation thread ID.""" - provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client) - - await provider.thread_created("thread123") - - assert provider._per_operation_thread_id == "thread123" - - async def test_thread_created_with_existing_thread_id(self, mock_mem0_client: AsyncMock) -> None: - """Test thread_created when thread ID already exists.""" - provider = Mem0Provider(user_id="user123", 
mem0_client=mock_mem0_client) - provider._per_operation_thread_id = "existing_thread" - - await provider.thread_created("thread123") - - # Should not overwrite existing thread ID - assert provider._per_operation_thread_id == "existing_thread" - - async def test_thread_created_validation_with_scope_enabled(self, mock_mem0_client: AsyncMock) -> None: - """Test thread_created validation when scope_to_per_operation_thread_id is enabled.""" - provider = Mem0Provider( - user_id="user123", - scope_to_per_operation_thread_id=True, - mem0_client=mock_mem0_client, - ) - provider._per_operation_thread_id = "existing_thread" - - with pytest.raises(ValueError) as exc_info: - await provider.thread_created("different_thread") - - assert "can only be used with one thread at a time" in str(exc_info.value) - - async def test_messages_adding_sets_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None: - """Test that invoked sets per-operation thread ID.""" - provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client) +class TestMem0ContextProviderAfterRun: + """Test after_run method (storing messages to Mem0).""" - await provider.thread_created("thread123") - - assert provider._per_operation_thread_id == "thread123" - - -class TestMem0ProviderMessagesAdding: - """Test invoked method.""" - - async def test_messages_adding_fails_without_filters(self, mock_mem0_client: AsyncMock) -> None: - """Test that invoked fails when no filters are provided.""" - provider = Mem0Provider(mem0_client=mock_mem0_client) - message = Message(role="user", text="Hello!") + async def test_after_run_fails_without_filters( + self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession + ) -> None: + """Test that after_run fails when no filters are provided.""" + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) + ctx = _make_context([Message(role="user", text="Hello!")]) with pytest.raises(ServiceInitializationError) as exc_info: - await 
provider.invoked(message) + await provider.after_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) assert "At least one of the filters" in str(exc_info.value) - async def test_messages_adding_single_message(self, mock_mem0_client: AsyncMock) -> None: - """Test adding a single message.""" - provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client) - message = Message(role="user", text="Hello!") + async def test_after_run_single_input_message( + self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession + ) -> None: + """Test storing a single input message.""" + provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) + ctx = _make_context([Message(role="user", text="Hello!")]) - await provider.invoked(message) + await provider.after_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) mock_mem0_client.add.assert_called_once() call_args = mock_mem0_client.add.call_args assert call_args.kwargs["messages"] == [{"role": "user", "content": "Hello!"}] assert call_args.kwargs["user_id"] == "user123" - async def test_messages_adding_multiple_messages( - self, mock_mem0_client: AsyncMock, sample_messages: list[Message] + async def test_after_run_multiple_messages( + self, + mock_mem0_client: AsyncMock, + mock_agent: AsyncMock, + session: AgentSession, + sample_messages: list[Message], ) -> None: - """Test adding multiple messages.""" - provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client) + """Test storing multiple input messages.""" + provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) + ctx = _make_context(sample_messages) - await provider.invoked(sample_messages) + await provider.after_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) mock_mem0_client.add.assert_called_once() call_args = mock_mem0_client.add.call_args @@ -193,308 +172,308 @@ async def 
test_messages_adding_multiple_messages( ] assert call_args.kwargs["messages"] == expected_messages - async def test_messages_adding_with_agent_id( - self, mock_mem0_client: AsyncMock, sample_messages: list[Message] + async def test_after_run_includes_response_messages( + self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession ) -> None: - """Test adding messages with agent_id.""" - provider = Mem0Provider(agent_id="agent123", mem0_client=mock_mem0_client) + """Test that after_run includes response messages.""" + provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) + ctx = _make_context([Message(role="user", text="Hello!")]) + ctx._response = AgentResponse(messages=[Message(role="assistant", text="Hi there!")]) - await provider.invoked(sample_messages) + await provider.after_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) + mock_mem0_client.add.assert_called_once() call_args = mock_mem0_client.add.call_args - assert call_args.kwargs["agent_id"] == "agent123" - assert call_args.kwargs["user_id"] is None + expected_messages = [ + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hi there!"}, + ] + assert call_args.kwargs["messages"] == expected_messages - async def test_messages_adding_with_application_id( - self, mock_mem0_client: AsyncMock, sample_messages: list[Message] + async def test_after_run_with_agent_id( + self, + mock_mem0_client: AsyncMock, + mock_agent: AsyncMock, + session: AgentSession, + sample_messages: list[Message], ) -> None: - """Test adding messages with application_id in metadata.""" - provider = Mem0Provider(user_id="user123", application_id="app123", mem0_client=mock_mem0_client) + """Test storing messages with agent_id.""" + provider = Mem0ContextProvider(source_id="mem0", agent_id="agent123", mem0_client=mock_mem0_client) + ctx = _make_context(sample_messages) - await provider.invoked(sample_messages) + await 
provider.after_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) call_args = mock_mem0_client.add.call_args - assert call_args.kwargs["metadata"] == {"application_id": "app123"} + assert call_args.kwargs["agent_id"] == "agent123" + assert call_args.kwargs["user_id"] is None - async def test_messages_adding_with_scope_to_per_operation_thread_id( - self, mock_mem0_client: AsyncMock, sample_messages: list[Message] + async def test_after_run_with_application_id( + self, + mock_mem0_client: AsyncMock, + mock_agent: AsyncMock, + session: AgentSession, + sample_messages: list[Message], ) -> None: - """Test adding messages with scope_to_per_operation_thread_id enabled.""" - provider = Mem0Provider( - user_id="user123", - thread_id="base_thread", - scope_to_per_operation_thread_id=True, - mem0_client=mock_mem0_client, + """Test storing messages with application_id in metadata.""" + provider = Mem0ContextProvider( + source_id="mem0", user_id="user123", application_id="app123", mem0_client=mock_mem0_client ) - provider._per_operation_thread_id = "operation_thread" + ctx = _make_context(sample_messages) - await provider.thread_created(thread_id="operation_thread") - await provider.invoked(sample_messages) + await provider.after_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) call_args = mock_mem0_client.add.call_args - assert call_args.kwargs["run_id"] == "operation_thread" + assert call_args.kwargs["metadata"] == {"application_id": "app123"} - async def test_messages_adding_without_scope_uses_base_thread_id( - self, mock_mem0_client: AsyncMock, sample_messages: list[Message] + async def test_after_run_uses_session_id_as_run_id( + self, + mock_mem0_client: AsyncMock, + mock_agent: AsyncMock, + session: AgentSession, + sample_messages: list[Message], ) -> None: - """Test adding messages without scope uses base thread_id.""" - provider = Mem0Provider( - user_id="user123", - thread_id="base_thread", - 
scope_to_per_operation_thread_id=False, - mem0_client=mock_mem0_client, - ) + """Test that after_run uses the context session_id as run_id.""" + provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) + ctx = _make_context(sample_messages, session_id="my-session") - await provider.invoked(sample_messages) + await provider.after_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) call_args = mock_mem0_client.add.call_args - assert call_args.kwargs["run_id"] == "base_thread" + assert call_args.kwargs["run_id"] == "my-session" - async def test_messages_adding_filters_empty_messages(self, mock_mem0_client: AsyncMock) -> None: + async def test_after_run_filters_empty_messages( + self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession + ) -> None: """Test that empty or invalid messages are filtered out.""" - provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client) + provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) messages = [ - Message(role="user", text=""), # Empty text - Message(role="user", text=" "), # Whitespace only + Message(role="user", text=""), + Message(role="user", text=" "), Message(role="user", text="Valid message"), ] + ctx = _make_context(messages) - await provider.invoked(messages) + await provider.after_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) call_args = mock_mem0_client.add.call_args - # Should only include the valid message assert call_args.kwargs["messages"] == [{"role": "user", "content": "Valid message"}] - async def test_messages_adding_skips_when_no_valid_messages(self, mock_mem0_client: AsyncMock) -> None: + async def test_after_run_skips_when_no_valid_messages( + self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession + ) -> None: """Test that mem0 client is not called when no valid messages exist.""" - provider = 
Mem0Provider(user_id="user123", mem0_client=mock_mem0_client) + provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) messages = [ Message(role="user", text=""), Message(role="user", text=" "), ] + ctx = _make_context(messages) - await provider.invoked(messages) + await provider.after_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) mock_mem0_client.add.assert_not_called() -class TestMem0ProviderModelInvoking: - """Test invoking method.""" +class TestMem0ContextProviderBeforeRun: + """Test before_run method (searching memories and adding to context).""" - async def test_model_invoking_fails_without_filters(self, mock_mem0_client: AsyncMock) -> None: - """Test that invoking fails when no filters are provided.""" - provider = Mem0Provider(mem0_client=mock_mem0_client) - message = Message(role="user", text="What's the weather?") + async def test_before_run_fails_without_filters( + self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession + ) -> None: + """Test that before_run fails when no filters are provided.""" + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) + ctx = _make_context([Message(role="user", text="What's the weather?")]) with pytest.raises(ServiceInitializationError) as exc_info: - await provider.invoking(message) + await provider.before_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) assert "At least one of the filters" in str(exc_info.value) - async def test_model_invoking_single_message(self, mock_mem0_client: AsyncMock) -> None: - """Test invoking with a single message.""" - provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client) - message = Message(role="user", text="What's the weather?") + async def test_before_run_single_message( + self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession + ) -> None: + """Test before_run with a single input message.""" + 
provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) + ctx = _make_context([Message(role="user", text="What's the weather?")]) - # Mock search results mock_mem0_client.search.return_value = [ {"memory": "User likes outdoor activities"}, {"memory": "User lives in Seattle"}, ] - context = await provider.invoking(message) + await provider.before_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) mock_mem0_client.search.assert_called_once() call_args = mock_mem0_client.search.call_args assert call_args.kwargs["query"] == "What's the weather?" - assert call_args.kwargs["filters"] == {"user_id": "user123"} + assert call_args.kwargs["filters"] == {"user_id": "user123", "run_id": "test-session"} - assert isinstance(context, Context) - expected_instructions = ( + context_messages = ctx.get_messages() + assert len(context_messages) > 0 + expected_text = ( "## Memories\nConsider the following memories when answering user questions:\n" "User likes outdoor activities\nUser lives in Seattle" ) - - assert context.messages - assert context.messages[0].text == expected_instructions - - async def test_model_invoking_multiple_messages( - self, mock_mem0_client: AsyncMock, sample_messages: list[Message] + assert context_messages[0].text == expected_text + + async def test_before_run_multiple_messages( + self, + mock_mem0_client: AsyncMock, + mock_agent: AsyncMock, + session: AgentSession, + sample_messages: list[Message], ) -> None: - """Test invoking with multiple messages.""" - provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client) + """Test before_run with multiple input messages.""" + provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) + ctx = _make_context(sample_messages) mock_mem0_client.search.return_value = [{"memory": "Previous conversation context"}] - await provider.invoking(sample_messages) + await provider.before_run(agent=mock_agent, 
session=session, context=ctx, state=_empty_state()) call_args = mock_mem0_client.search.call_args expected_query = "Hello, how are you?\nI'm doing well, thank you!\nYou are a helpful assistant" assert call_args.kwargs["query"] == expected_query - async def test_model_invoking_with_agent_id(self, mock_mem0_client: AsyncMock) -> None: - """Test invoking with agent_id.""" - provider = Mem0Provider(agent_id="agent123", mem0_client=mock_mem0_client) - message = Message(role="user", text="Hello") + async def test_before_run_with_agent_id( + self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession + ) -> None: + """Test before_run with agent_id.""" + provider = Mem0ContextProvider(source_id="mem0", agent_id="agent123", mem0_client=mock_mem0_client) + ctx = _make_context([Message(role="user", text="Hello")]) mock_mem0_client.search.return_value = [] - await provider.invoking(message) + await provider.before_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) call_args = mock_mem0_client.search.call_args - assert call_args.kwargs["filters"] == {"agent_id": "agent123"} + assert call_args.kwargs["filters"] == {"agent_id": "agent123", "run_id": "test-session"} - async def test_model_invoking_with_scope_to_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None: - """Test invoking with scope_to_per_operation_thread_id enabled.""" - provider = Mem0Provider( - user_id="user123", - thread_id="base_thread", - scope_to_per_operation_thread_id=True, - mem0_client=mock_mem0_client, - ) - provider._per_operation_thread_id = "operation_thread" - message = Message(role="user", text="Hello") + async def test_before_run_with_session_id_in_filters( + self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession + ) -> None: + """Test before_run includes session_id as run_id in search filters.""" + provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) + ctx = 
_make_context([Message(role="user", text="Hello")], session_id="my-session") mock_mem0_client.search.return_value = [] - await provider.invoking(message) + await provider.before_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) call_args = mock_mem0_client.search.call_args - assert call_args.kwargs["filters"] == {"user_id": "user123", "run_id": "operation_thread"} + assert call_args.kwargs["filters"] == {"user_id": "user123", "run_id": "my-session"} - async def test_model_invoking_no_memories_returns_none_instructions(self, mock_mem0_client: AsyncMock) -> None: - """Test that no memories returns context with None instructions.""" - provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client) - message = Message(role="user", text="Hello") + async def test_before_run_no_memories_does_not_add_messages( + self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession + ) -> None: + """Test that no memories does not add context messages.""" + provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) + ctx = _make_context([Message(role="user", text="Hello")]) mock_mem0_client.search.return_value = [] - context = await provider.invoking(message) + await provider.before_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) - assert isinstance(context, Context) - assert not context.messages + context_messages = ctx.get_messages() + assert len(context_messages) == 0 - async def test_model_invoking_function_approval_response_returns_none_instructions( - self, mock_mem0_client: AsyncMock + async def test_before_run_empty_input_text_skips_search( + self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession ) -> None: - """Test invoking with function approval response content messages returns context with None instructions.""" - - provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client) - function_call = 
Content.from_function_call(call_id="1", name="test_func", arguments='{"arg1": "value1"}') - message = Message( - role="user", - contents=[ - Content.from_function_approval_response( - id="approval_1", - function_call=function_call, - approved=True, - ) - ], - ) - - mock_mem0_client.search.return_value = [] + """Test that empty input text skips the search entirely.""" + provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) + ctx = _make_context([Message(role="user", text=""), Message(role="user", text=" ")]) - context = await provider.invoking(message) + await provider.before_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) - assert isinstance(context, Context) - assert not context.messages + mock_mem0_client.search.assert_not_called() - async def test_model_invoking_filters_empty_message_text(self, mock_mem0_client: AsyncMock) -> None: + async def test_before_run_filters_empty_message_text( + self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession + ) -> None: """Test that empty message text is filtered out from query.""" - provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client) + provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) messages = [ Message(role="user", text=""), Message(role="user", text="Valid message"), Message(role="user", text=" "), ] + ctx = _make_context(messages) mock_mem0_client.search.return_value = [] - await provider.invoking(messages) + await provider.before_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) call_args = mock_mem0_client.search.call_args assert call_args.kwargs["query"] == "Valid message" - async def test_model_invoking_custom_context_prompt(self, mock_mem0_client: AsyncMock) -> None: - """Test invoking with custom context prompt.""" + async def test_before_run_custom_context_prompt( + self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, 
session: AgentSession + ) -> None: + """Test before_run with custom context prompt.""" custom_prompt = "## Custom Context\nRemember these details:" - provider = Mem0Provider( + provider = Mem0ContextProvider( + source_id="mem0", user_id="user123", context_prompt=custom_prompt, mem0_client=mock_mem0_client, ) - message = Message(role="user", text="Hello") + ctx = _make_context([Message(role="user", text="Hello")]) mock_mem0_client.search.return_value = [{"memory": "Test memory"}] - context = await provider.invoking(message) + await provider.before_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) - expected_instructions = "## Custom Context\nRemember these details:\nTest memory" - assert context.messages - assert context.messages[0].text == expected_instructions + context_messages = ctx.get_messages() + expected_text = "## Custom Context\nRemember these details:\nTest memory" + assert len(context_messages) > 0 + assert context_messages[0].text == expected_text -class TestMem0ProviderValidation: +class TestMem0ContextProviderValidation: """Test validation methods.""" - def test_validate_per_operation_thread_id_success(self, mock_mem0_client: AsyncMock) -> None: - """Test successful validation of per-operation thread ID.""" - provider = Mem0Provider( - user_id="user123", - scope_to_per_operation_thread_id=True, - mem0_client=mock_mem0_client, - ) - provider._per_operation_thread_id = "thread123" + def test_validate_filters_fails_without_any_filter(self, mock_mem0_client: AsyncMock) -> None: + """Test validation failure when no filters are set.""" + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) - # Should not raise exception for same thread ID - provider._validate_per_operation_thread_id("thread123") - - # Should not raise exception for None - provider._validate_per_operation_thread_id(None) + with pytest.raises(ServiceInitializationError) as exc_info: + provider._validate_filters() - def 
test_validate_per_operation_thread_id_failure(self, mock_mem0_client: AsyncMock) -> None: - """Test validation failure for conflicting thread IDs.""" - provider = Mem0Provider( - user_id="user123", - scope_to_per_operation_thread_id=True, - mem0_client=mock_mem0_client, - ) - provider._per_operation_thread_id = "thread123" + assert "At least one of the filters" in str(exc_info.value) - with pytest.raises(ValueError) as exc_info: - provider._validate_per_operation_thread_id("different_thread") + def test_validate_filters_succeeds_with_user_id(self, mock_mem0_client: AsyncMock) -> None: + """Test validation succeeds with user_id set.""" + provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) + provider._validate_filters() # Should not raise - assert "can only be used with one thread at a time" in str(exc_info.value) + def test_validate_filters_succeeds_with_agent_id(self, mock_mem0_client: AsyncMock) -> None: + """Test validation succeeds with agent_id set.""" + provider = Mem0ContextProvider(source_id="mem0", agent_id="agent123", mem0_client=mock_mem0_client) + provider._validate_filters() # Should not raise - def test_validate_per_operation_thread_id_disabled_scope(self, mock_mem0_client: AsyncMock) -> None: - """Test that validation is skipped when scope is disabled.""" - provider = Mem0Provider( - user_id="user123", - scope_to_per_operation_thread_id=False, - mem0_client=mock_mem0_client, - ) - provider._per_operation_thread_id = "thread123" + def test_validate_filters_succeeds_with_application_id(self, mock_mem0_client: AsyncMock) -> None: + """Test validation succeeds with application_id set.""" + provider = Mem0ContextProvider(source_id="mem0", application_id="app123", mem0_client=mock_mem0_client) + provider._validate_filters() # Should not raise - # Should not raise exception even with different thread ID - provider._validate_per_operation_thread_id("different_thread") - -class TestMem0ProviderBuildFilters: +class 
TestMem0ContextProviderBuildFilters: """Test the _build_filters method.""" def test_build_filters_with_user_id_only(self, mock_mem0_client: AsyncMock) -> None: """Test building filters with only user_id.""" - provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client) + provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) filters = provider._build_filters() assert filters == {"user_id": "user123"} def test_build_filters_with_all_parameters(self, mock_mem0_client: AsyncMock) -> None: """Test building filters with all initialization parameters.""" - provider = Mem0Provider( + provider = Mem0ContextProvider( + source_id="mem0", user_id="user123", agent_id="agent456", - thread_id="thread789", application_id="app999", mem0_client=mock_mem0_client, ) @@ -503,16 +482,15 @@ def test_build_filters_with_all_parameters(self, mock_mem0_client: AsyncMock) -> assert filters == { "user_id": "user123", "agent_id": "agent456", - "run_id": "thread789", "app_id": "app999", } def test_build_filters_excludes_none_values(self, mock_mem0_client: AsyncMock) -> None: """Test that None values are excluded from filters.""" - provider = Mem0Provider( + provider = Mem0ContextProvider( + source_id="mem0", user_id="user123", agent_id=None, - thread_id=None, application_id=None, mem0_client=mock_mem0_client, ) @@ -520,44 +498,25 @@ def test_build_filters_excludes_none_values(self, mock_mem0_client: AsyncMock) - filters = provider._build_filters() assert filters == {"user_id": "user123"} assert "agent_id" not in filters - assert "run_id" not in filters assert "app_id" not in filters - def test_build_filters_with_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None: - """Test that per-operation thread ID takes precedence over base thread_id.""" - provider = Mem0Provider( + def test_build_filters_with_session_id(self, mock_mem0_client: AsyncMock) -> None: + """Test that session_id is included as run_id in filters.""" + provider 
= Mem0ContextProvider( + source_id="mem0", user_id="user123", - thread_id="base_thread", - scope_to_per_operation_thread_id=True, mem0_client=mock_mem0_client, ) - provider._per_operation_thread_id = "operation_thread" - - filters = provider._build_filters() - assert filters == { - "user_id": "user123", - "run_id": "operation_thread", # Per-operation thread, not base_thread - } - def test_build_filters_uses_base_thread_when_no_per_operation(self, mock_mem0_client: AsyncMock) -> None: - """Test that base thread_id is used when per-operation thread is not set.""" - provider = Mem0Provider( - user_id="user123", - thread_id="base_thread", - scope_to_per_operation_thread_id=True, - mem0_client=mock_mem0_client, - ) - # _per_operation_thread_id is None - - filters = provider._build_filters() + filters = provider._build_filters(session_id="session-123") assert filters == { "user_id": "user123", - "run_id": "base_thread", # Falls back to base thread_id + "run_id": "session-123", } def test_build_filters_returns_empty_dict_when_no_parameters(self, mock_mem0_client: AsyncMock) -> None: """Test that _build_filters returns an empty dict when no parameters are set.""" - provider = Mem0Provider(mem0_client=mock_mem0_client) + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) filters = provider._build_filters() assert filters == {} diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py index ce365724ef..3bc07cd67c 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py @@ -222,7 +222,7 @@ def __init__( autonomous_mode_turn_limit: Maximum number of autonomous turns before requesting user input. 
""" cloned_agent = self._prepare_agent_with_handoffs(agent, handoffs) - super().__init__(cloned_agent, agent_thread=agent_thread) + super().__init__(cloned_agent, session=agent_thread) self._handoff_targets = {handoff.target_id for handoff in handoffs} self._termination_condition = termination_condition @@ -306,8 +306,7 @@ def _clone_chat_agent(self, agent: Agent) -> Agent: id=agent.id, name=agent.name, description=agent.description, - chat_message_store_factory=agent.chat_message_store_factory, - context_provider=agent.context_provider, + context_providers=agent.context_providers, middleware=middleware, default_options=cloned_options, # type: ignore[arg-type] ) diff --git a/python/packages/orchestrations/tests/test_group_chat.py b/python/packages/orchestrations/tests/test_group_chat.py index 79bc62d6c5..59bae72314 100644 --- a/python/packages/orchestrations/tests/test_group_chat.py +++ b/python/packages/orchestrations/tests/test_group_chat.py @@ -9,7 +9,7 @@ AgentExecutorResponse, AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, BaseAgent, ChatResponse, ChatResponseUpdate, @@ -41,7 +41,7 @@ def run( # type: ignore[override] messages: str | Message | Sequence[str | Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse] | AsyncIterable[AgentResponseUpdate]: if stream: @@ -78,7 +78,7 @@ async def run( self, messages: str | Message | Sequence[str | Message] | None = None, *, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> AgentResponse: if self._call_count == 0: @@ -898,7 +898,7 @@ async def run( self, messages: str | Message | Sequence[str | Message] | None = None, *, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> AgentResponse: if self._call_count == 0: diff --git a/python/packages/orchestrations/tests/test_handoff.py 
b/python/packages/orchestrations/tests/test_handoff.py index 38ff6ea49a..fe43fe4387 100644 --- a/python/packages/orchestrations/tests/test_handoff.py +++ b/python/packages/orchestrations/tests/test_handoff.py @@ -11,7 +11,7 @@ ChatResponseUpdate, Content, Context, - ContextProvider, + BaseContextProvider, Message, ResponseStream, WorkflowEvent, @@ -306,16 +306,18 @@ async def mock_get_response(messages: Any, options: dict[str, Any] | None = None async def test_context_provider_preserved_during_handoff(): - """Verify that context_provider is preserved when cloning agents in handoff workflows.""" + """Verify that context_providers are preserved when cloning agents in handoff workflows.""" # Track whether context provider methods were called provider_calls: list[str] = [] - class TestContextProvider(ContextProvider): + class TestContextProvider(BaseContextProvider): """A test context provider that tracks its invocations.""" - async def invoking(self, messages: Sequence[Message], **kwargs: Any) -> Context: - provider_calls.append("invoking") - return Context(instructions="Test context from provider.") + def __init__(self) -> None: + super().__init__("test") + + async def before_run(self, **kwargs: Any) -> None: + provider_calls.append("before_run") # Create context provider context_provider = TestContextProvider() @@ -328,13 +330,13 @@ async def invoking(self, messages: Sequence[Message], **kwargs: Any) -> Context: client=mock_client, name="test_agent", id="test_agent", - context_provider=context_provider, + context_providers=[context_provider], ) # Verify the original agent has the context provider - assert agent.context_provider is context_provider, "Original agent should have context provider" + assert context_provider in agent.context_providers, "Original agent should have context provider" - # Build handoff workflow - this should clone the agent and preserve context_provider + # Build handoff workflow - this should clone the agent and preserve context_providers 
workflow = HandoffBuilder(participants=[agent]).with_start_agent(agent).build() # Run workflow with a simple message to trigger context provider diff --git a/python/packages/orchestrations/tests/test_magentic.py b/python/packages/orchestrations/tests/test_magentic.py index c887833545..25c068b9fe 100644 --- a/python/packages/orchestrations/tests/test_magentic.py +++ b/python/packages/orchestrations/tests/test_magentic.py @@ -9,7 +9,7 @@ from agent_framework import ( AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, BaseAgent, Content, Executor, @@ -153,7 +153,7 @@ def run( # type: ignore[override] messages: str | Message | Sequence[str | Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse] | AsyncIterable[AgentResponseUpdate]: if stream: @@ -414,7 +414,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: bool = False, - thread: Any = None, + session: Any = None, **kwargs: Any, ) -> Awaitable[AgentResponse] | AsyncIterable[AgentResponseUpdate]: if stream: @@ -526,7 +526,7 @@ class StubThreadAgent(BaseAgent): def __init__(self, name: str | None = None) -> None: super().__init__(name=name or "agentA") - def run(self, messages=None, *, stream: bool = False, thread=None, **kwargs): # type: ignore[override] + def run(self, messages=None, *, stream: bool = False, session=None, **kwargs): # type: ignore[override] if stream: return self._run_stream() @@ -554,7 +554,7 @@ def __init__(self) -> None: super().__init__(name="agentA") self.client = StubAssistantsClient() # type name contains 'AssistantsClient' - def run(self, messages=None, *, stream: bool = False, thread=None, **kwargs): # type: ignore[override] + def run(self, messages=None, *, stream: bool = False, session=None, **kwargs): # type: ignore[override] if stream: return self._run_stream() diff --git 
a/python/packages/orchestrations/tests/test_orchestration_request_info.py b/python/packages/orchestrations/tests/test_orchestration_request_info.py index 1e2b8a4af6..7d0acbc945 100644 --- a/python/packages/orchestrations/tests/test_orchestration_request_info.py +++ b/python/packages/orchestrations/tests/test_orchestration_request_info.py @@ -10,7 +10,7 @@ from agent_framework import ( AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, Message, SupportsAgentRun, ) @@ -203,7 +203,7 @@ async def run( messages: str | Message | list[str] | list[Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + thread: AgentSession | None = None, **kwargs: Any, ) -> AgentResponse | AsyncIterable[AgentResponseUpdate]: """Dummy run method.""" @@ -214,9 +214,9 @@ async def run( async def _run_stream_impl(self) -> AsyncIterable[AgentResponseUpdate]: yield AgentResponseUpdate(messages=[Message(role="assistant", text="Test response stream")]) - def get_new_thread(self, **kwargs: Any) -> AgentThread: - """Creates a new conversation thread for the agent.""" - return AgentThread(**kwargs) + def create_session(self, **kwargs: Any) -> AgentSession: + """Creates a new conversation session for the agent.""" + return AgentSession(**kwargs) class TestAgentApprovalExecutor: diff --git a/python/packages/orchestrations/tests/test_sequential.py b/python/packages/orchestrations/tests/test_sequential.py index 04a4ae4141..67bcc1bb9e 100644 --- a/python/packages/orchestrations/tests/test_sequential.py +++ b/python/packages/orchestrations/tests/test_sequential.py @@ -8,7 +8,7 @@ AgentExecutorResponse, AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, BaseAgent, Content, Executor, @@ -30,7 +30,7 @@ def run( # type: ignore[override] messages: str | Message | list[str] | list[Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse] | 
AsyncIterable[AgentResponseUpdate]: if stream: diff --git a/python/packages/purview/tests/purview/test_middleware.py b/python/packages/purview/tests/purview/test_middleware.py index 451aaf9df7..3a34c48344 100644 --- a/python/packages/purview/tests/purview/test_middleware.py +++ b/python/packages/purview/tests/purview/test_middleware.py @@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from agent_framework import AgentContext, AgentResponse, AgentThread, Message, MiddlewareTermination +from agent_framework import AgentContext, AgentResponse, AgentSession, Message, MiddlewareTermination from azure.core.credentials import AccessToken from agent_framework_purview import PurviewPolicyMiddleware, PurviewSettings @@ -337,12 +337,12 @@ async def mock_next(): with pytest.raises(ValueError, match="Test error"): await middleware.process(context, mock_next) - async def test_middleware_uses_thread_service_thread_id_as_session_id( + async def test_middleware_uses_session_service_session_id_as_session_id( self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock ) -> None: - """Test that session_id is extracted from thread.service_thread_id.""" - thread = AgentThread(service_thread_id="thread-123") - context = AgentContext(agent=mock_agent, messages=[Message(role="user", text="Hello")], thread=thread) + """Test that session_id is extracted from session.service_session_id.""" + session = AgentSession(service_session_id="thread-123") + context = AgentContext(agent=mock_agent, messages=[Message(role="user", text="Hello")], session=session) with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: @@ -373,13 +373,13 @@ async def mock_next() -> None: assert mock_proc.call_count == 2 mock_proc.assert_any_call(messages, Activity.UPLOAD_TEXT, session_id="conv-456") - async def test_middleware_thread_id_takes_precedence_over_message_conversation_id( + async def 
test_middleware_session_id_takes_precedence_over_message_conversation_id( self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock ) -> None: - """Test that thread.service_thread_id takes precedence over message conversation_id.""" - thread = AgentThread(service_thread_id="thread-789") + """Test that session.service_session_id takes precedence over message conversation_id.""" + session = AgentSession(service_session_id="thread-789") messages = [Message(role="user", text="Hello", additional_properties={"conversation_id": "conv-456"})] - context = AgentContext(agent=mock_agent, messages=messages, thread=thread) + context = AgentContext(agent=mock_agent, messages=messages, session=session) with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: @@ -388,13 +388,13 @@ async def mock_next() -> None: await middleware.process(context, mock_next) - # Verify thread ID is used, not message conversation_id + # Verify session ID is used, not message conversation_id mock_proc.assert_any_call(messages, Activity.UPLOAD_TEXT, session_id="thread-789") async def test_middleware_passes_none_session_id_when_not_available( self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock ) -> None: - """Test that session_id is None when no thread or conversation_id is available.""" + """Test that session_id is None when no session or conversation_id is available.""" context = AgentContext(agent=mock_agent, messages=[Message(role="user", text="Hello")]) with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: @@ -411,8 +411,8 @@ async def test_middleware_session_id_used_in_post_check( self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock ) -> None: """Test that session_id is passed to post-check process_messages call.""" - thread = AgentThread(service_thread_id="thread-999") - context = AgentContext(agent=mock_agent, messages=[Message(role="user", text="Hello")], 
thread=thread) + session = AgentSession(service_session_id="thread-999") + context = AgentContext(agent=mock_agent, messages=[Message(role="user", text="Hello")], session=session) with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: diff --git a/python/packages/redis/tests/test_redis_chat_message_store.py b/python/packages/redis/tests/test_redis_chat_message_store.py deleted file mode 100644 index 99a3038870..0000000000 --- a/python/packages/redis/tests/test_redis_chat_message_store.py +++ /dev/null @@ -1,621 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from agent_framework import Content, Message - -from agent_framework_redis import RedisChatMessageStore - - -class TestRedisChatMessageStore: - """Unit tests for RedisChatMessageStore using mocked Redis client. - - These tests use mocked Redis operations to verify the logic and behavior - of the RedisChatMessageStore without requiring a real Redis server. 
- """ - - @pytest.fixture - def sample_messages(self): - """Sample chat messages for testing.""" - return [ - Message(role="user", text="Hello", message_id="msg1"), - Message(role="assistant", text="Hi there!", message_id="msg2"), - Message(role="user", text="How are you?", message_id="msg3"), - ] - - @pytest.fixture - def mock_redis_client(self): - """Mock Redis client with all required methods.""" - client = MagicMock() - # Core list operations - client.lrange = AsyncMock(return_value=[]) - client.llen = AsyncMock(return_value=0) - client.lindex = AsyncMock(return_value=None) - client.lset = AsyncMock(return_value=True) - client.lrem = AsyncMock(return_value=0) - client.lpop = AsyncMock(return_value=None) - client.rpop = AsyncMock(return_value=None) - client.ltrim = AsyncMock(return_value=True) - client.delete = AsyncMock(return_value=1) - - # Pipeline operations - mock_pipeline = AsyncMock() - mock_pipeline.rpush = AsyncMock() - mock_pipeline.execute = AsyncMock() - client.pipeline.return_value.__aenter__.return_value = mock_pipeline - - return client - - @pytest.fixture - def redis_store(self, mock_redis_client): - """Redis chat message store with mocked client.""" - with patch("agent_framework_redis._chat_message_store.redis.from_url") as mock_from_url: - mock_from_url.return_value = mock_redis_client - store = RedisChatMessageStore(redis_url="redis://localhost:6379", thread_id="test_thread_123") - store._redis_client = mock_redis_client - return store - - def test_init_with_thread_id(self): - """Test initialization with explicit thread ID.""" - thread_id = "user123_session456" - with patch("agent_framework_redis._chat_message_store.redis.from_url"): - store = RedisChatMessageStore(redis_url="redis://localhost:6379", thread_id=thread_id) - - assert store.thread_id == thread_id - assert store.redis_url == "redis://localhost:6379" - assert store.key_prefix == "chat_messages" - assert store.redis_key == f"chat_messages:{thread_id}" - - def 
test_init_auto_generate_thread_id(self): - """Test initialization with auto-generated thread ID.""" - with patch("agent_framework_redis._chat_message_store.redis.from_url"): - store = RedisChatMessageStore(redis_url="redis://localhost:6379") - - assert store.thread_id is not None - assert store.thread_id.startswith("thread_") - assert len(store.thread_id) > 10 # Should be a UUID - - def test_init_with_custom_prefix(self): - """Test initialization with custom key prefix.""" - with patch("agent_framework_redis._chat_message_store.redis.from_url"): - store = RedisChatMessageStore( - redis_url="redis://localhost:6379", thread_id="test123", key_prefix="custom_messages" - ) - - assert store.redis_key == "custom_messages:test123" - - def test_init_with_max_messages(self): - """Test initialization with message limit.""" - with patch("agent_framework_redis._chat_message_store.redis.from_url"): - store = RedisChatMessageStore(redis_url="redis://localhost:6379", thread_id="test123", max_messages=100) - - assert store.max_messages == 100 - - def test_init_with_redis_url_required(self): - """Test that either redis_url or credential_provider is required.""" - with pytest.raises(ValueError, match="Either redis_url or credential_provider must be provided"): - RedisChatMessageStore(thread_id="test123") - - def test_init_with_credential_provider(self): - """Test initialization with credential_provider.""" - mock_credential_provider = MagicMock() - - with patch("agent_framework_redis._chat_message_store.redis.Redis") as mock_redis_class: - mock_redis_instance = MagicMock() - mock_redis_class.return_value = mock_redis_instance - - store = RedisChatMessageStore( - credential_provider=mock_credential_provider, - host="myredis.redis.cache.windows.net", - thread_id="test123", - ) - - # Verify Redis.Redis was called with correct parameters - mock_redis_class.assert_called_once_with( - host="myredis.redis.cache.windows.net", - port=6380, - ssl=True, - username=None, - 
credential_provider=mock_credential_provider, - decode_responses=True, - ) - # Verify store instance is properly initialized - assert store.thread_id == "test123" - assert store.redis_url is None # Should be None for credential provider auth - assert store.key_prefix == "chat_messages" - assert store.max_messages is None - - def test_init_with_credential_provider_custom_port(self): - """Test initialization with credential_provider and custom port.""" - mock_credential_provider = MagicMock() - - with patch("agent_framework_redis._chat_message_store.redis.Redis") as mock_redis_class: - mock_redis_instance = MagicMock() - mock_redis_class.return_value = mock_redis_instance - - store = RedisChatMessageStore( - credential_provider=mock_credential_provider, - host="myredis.redis.cache.windows.net", - port=6379, - ssl=False, - username="admin", - thread_id="test123", - ) - - # Verify custom parameters were passed - mock_redis_class.assert_called_once_with( - host="myredis.redis.cache.windows.net", - port=6379, - ssl=False, - username="admin", - credential_provider=mock_credential_provider, - decode_responses=True, - ) - # Verify store instance is properly initialized - assert store.thread_id == "test123" - assert store.redis_url is None # Should be None for credential provider auth - assert store.key_prefix == "chat_messages" - - def test_init_credential_provider_requires_host(self): - """Test that credential_provider requires host parameter.""" - mock_credential_provider = MagicMock() - - with pytest.raises(ValueError, match="host is required when using credential_provider"): - RedisChatMessageStore( - credential_provider=mock_credential_provider, - thread_id="test123", - ) - - def test_init_mutually_exclusive_params(self): - """Test that redis_url and credential_provider are mutually exclusive.""" - mock_credential_provider = MagicMock() - - with pytest.raises(ValueError, match="redis_url and credential_provider are mutually exclusive"): - RedisChatMessageStore( - 
redis_url="redis://localhost:6379", - credential_provider=mock_credential_provider, - host="myredis.redis.cache.windows.net", - thread_id="test123", - ) - - async def test_serialize_with_credential_provider(self): - """Test that serialization works correctly with credential provider authentication.""" - mock_credential_provider = MagicMock() - - with patch("agent_framework_redis._chat_message_store.redis.Redis") as mock_redis_class: - mock_redis_instance = MagicMock() - mock_redis_class.return_value = mock_redis_instance - - store = RedisChatMessageStore( - credential_provider=mock_credential_provider, - host="myredis.redis.cache.windows.net", - thread_id="test123", - key_prefix="custom_prefix", - max_messages=100, - ) - - # Serialize the store state - state = await store.serialize() - - # Verify serialization includes correct values - assert state["thread_id"] == "test123" - assert state["redis_url"] is None # Should be None for credential provider auth - assert state["key_prefix"] == "custom_prefix" - assert state["max_messages"] == 100 - assert state["type"] == "redis_store_state" - - def test_init_with_initial_messages(self, sample_messages): - """Test initialization with initial messages.""" - with patch("agent_framework_redis._chat_message_store.redis.from_url"): - store = RedisChatMessageStore( - redis_url="redis://localhost:6379", thread_id="test123", messages=sample_messages - ) - - assert store._initial_messages == sample_messages - - async def test_add_messages_single(self, redis_store, mock_redis_client, sample_messages): - """Test adding a single message using pipeline operations.""" - message = sample_messages[0] - - await redis_store.add_messages([message]) - - # Verify pipeline operations were called - mock_redis_client.pipeline.assert_called_with(transaction=True) - - # Get the pipeline mock and verify it was used correctly - pipeline_mock = mock_redis_client.pipeline.return_value.__aenter__.return_value - pipeline_mock.rpush.assert_called() - 
pipeline_mock.execute.assert_called() - - async def test_add_messages_multiple(self, redis_store, mock_redis_client, sample_messages): - """Test adding multiple messages using pipeline operations.""" - await redis_store.add_messages(sample_messages) - - # Verify pipeline operations - mock_redis_client.pipeline.assert_called_with(transaction=True) - - # Verify rpush was called for each message - pipeline_mock = mock_redis_client.pipeline.return_value.__aenter__.return_value - assert pipeline_mock.rpush.call_count == len(sample_messages) - - async def test_add_messages_with_max_limit(self, mock_redis_client): - """Test adding messages with max limit triggers trimming.""" - with patch("agent_framework_redis._chat_message_store.redis.from_url") as mock_from_url: - mock_from_url.return_value = mock_redis_client - - # Mock llen to return count that exceeds limit after adding - mock_redis_client.llen.return_value = 5 - - store = RedisChatMessageStore(redis_url="redis://localhost:6379", thread_id="test123", max_messages=3) - store._redis_client = mock_redis_client - - message = Message(role="user", text="Test") - await store.add_messages([message]) - - # Should trim after adding to keep only last 3 messages - mock_redis_client.ltrim.assert_called_once_with("chat_messages:test123", -3, -1) - - async def test_list_messages_empty(self, redis_store, mock_redis_client): - """Test listing messages when store is empty.""" - mock_redis_client.lrange.return_value = [] - - messages = await redis_store.list_messages() - - assert messages == [] - mock_redis_client.lrange.assert_called_once_with("chat_messages:test_thread_123", 0, -1) - - async def test_list_messages_with_data(self, redis_store, mock_redis_client, sample_messages): - """Test listing messages with data in Redis.""" - # Create proper serialized messages using the actual serialization method - test_messages = [ - Message(role="user", text="Hello", message_id="msg1"), - Message(role="assistant", text="Hi there!", 
message_id="msg2"), - ] - serialized_messages = [redis_store._serialize_message(msg) for msg in test_messages] - mock_redis_client.lrange.return_value = serialized_messages - - messages = await redis_store.list_messages() - - assert len(messages) == 2 - assert messages[0].role == "user" - assert messages[0].text == "Hello" - assert messages[1].role == "assistant" - assert messages[1].text == "Hi there!" - - async def test_list_messages_with_initial_messages(self, sample_messages): - """Test that initial messages are added to Redis and retrieved correctly.""" - with patch("agent_framework_redis._chat_message_store.redis.from_url") as mock_from_url: - mock_redis_client = MagicMock() - mock_redis_client.llen = AsyncMock(return_value=0) # Redis key is empty - mock_redis_client.lrange = AsyncMock(return_value=[]) - - # Mock pipeline for adding initial messages - mock_pipeline = AsyncMock() - mock_pipeline.rpush = AsyncMock() - mock_pipeline.execute = AsyncMock() - mock_redis_client.pipeline.return_value.__aenter__.return_value = mock_pipeline - - mock_from_url.return_value = mock_redis_client - - store = RedisChatMessageStore( - redis_url="redis://localhost:6379", - thread_id="test123", - messages=sample_messages[:1], # One initial message - ) - store._redis_client = mock_redis_client - - # Mock Redis to return the initial message after it's added - initial_message_json = store._serialize_message(sample_messages[0]) - mock_redis_client.lrange.return_value = [initial_message_json] - - messages = await store.list_messages() - - assert len(messages) == 1 - assert messages[0].text == "Hello" - # Verify initial message was added to Redis via pipeline - mock_pipeline.rpush.assert_called() - - async def test_initial_messages_not_added_if_key_exists(self, sample_messages): - """Test that initial messages are not added if Redis key already has data.""" - with patch("agent_framework_redis._chat_message_store.redis.from_url") as mock_from_url: - mock_redis_client = MagicMock() - 
mock_redis_client.llen = AsyncMock(return_value=5) # Key already has messages - mock_redis_client.lrange = AsyncMock(return_value=[]) - - # Pipeline should not be called since key already exists - mock_pipeline = AsyncMock() - mock_pipeline.rpush = AsyncMock() - mock_pipeline.execute = AsyncMock() - mock_redis_client.pipeline.return_value.__aenter__.return_value = mock_pipeline - - mock_from_url.return_value = mock_redis_client - - store = RedisChatMessageStore( - redis_url="redis://localhost:6379", - thread_id="test123", - messages=sample_messages[:1], # One initial message - ) - store._redis_client = mock_redis_client - - await store.list_messages() - - # Should check length but not add messages since key exists - mock_redis_client.llen.assert_called() - mock_pipeline.rpush.assert_not_called() - - async def test_serialize_state(self, redis_store): - """Test state serialization.""" - state = await redis_store.serialize() - - expected_state = { - "type": "redis_store_state", - "thread_id": "test_thread_123", - "redis_url": "redis://localhost:6379", - "key_prefix": "chat_messages", - "max_messages": None, - } - - assert state == expected_state - - async def test_deserialize_state(self, redis_store): - """Test state deserialization.""" - serialized_state = { - "thread_id": "restored_thread_456", - "redis_url": "redis://localhost:6380", - "key_prefix": "restored_messages", - "max_messages": 50, - } - - await redis_store.update_from_state(serialized_state) - - assert redis_store.thread_id == "restored_thread_456" - assert redis_store.redis_url == "redis://localhost:6380" - assert redis_store.key_prefix == "restored_messages" - assert redis_store.max_messages == 50 - - async def test_deserialize_state_empty(self, redis_store): - """Test deserializing empty state doesn't change anything.""" - original_thread_id = redis_store.thread_id - - await redis_store.update_from_state(None) - - assert redis_store.thread_id == original_thread_id - - async def 
test_clear_messages(self, redis_store, mock_redis_client): - """Test clearing all messages.""" - await redis_store.clear() - - mock_redis_client.delete.assert_called_once_with("chat_messages:test_thread_123") - - async def test_message_serialization_roundtrip(self, sample_messages): - """Test message serialization and deserialization roundtrip.""" - with patch("agent_framework_redis._chat_message_store.redis.from_url"): - store = RedisChatMessageStore(redis_url="redis://localhost:6379", thread_id="test123") - - message = sample_messages[0] - - # Test serialization - serialized = store._serialize_message(message) - assert isinstance(serialized, str) - - # Test deserialization - deserialized = store._deserialize_message(serialized) - assert deserialized.role == message.role - assert deserialized.text == message.text - assert deserialized.message_id == message.message_id - - async def test_message_serialization_with_complex_content(self): - """Test serialization of messages with complex content.""" - with patch("agent_framework_redis._chat_message_store.redis.from_url"): - store = RedisChatMessageStore(redis_url="redis://localhost:6379", thread_id="test123") - - # Message with multiple content types - message = Message( - role="assistant", - contents=[Content.from_text(text="Hello"), Content.from_text(text="World")], - author_name="TestBot", - message_id="complex_msg", - additional_properties={"metadata": "test"}, - ) - - serialized = store._serialize_message(message) - deserialized = store._deserialize_message(serialized) - - assert deserialized.role == "assistant" - assert deserialized.text == "Hello World" - assert deserialized.author_name == "TestBot" - assert deserialized.message_id == "complex_msg" - assert deserialized.additional_properties == {"metadata": "test"} - - async def test_redis_connection_error_handling(self): - """Test handling Redis connection errors in add_messages.""" - with patch("agent_framework_redis._chat_message_store.redis.from_url") as 
mock_from_url: - mock_client = MagicMock() - - # Mock pipeline to raise exception during execution - mock_pipeline = AsyncMock() - mock_pipeline.rpush = AsyncMock() - mock_pipeline.execute = AsyncMock(side_effect=Exception("Connection failed")) - mock_client.pipeline.return_value.__aenter__.return_value = mock_pipeline - - mock_from_url.return_value = mock_client - - store = RedisChatMessageStore(redis_url="redis://localhost:6379", thread_id="test123") - store._redis_client = mock_client - - message = Message(role="user", text="Test") - - # Should propagate Redis connection errors - with pytest.raises(Exception, match="Connection failed"): - await store.add_messages([message]) - - async def test_getitem(self, redis_store, mock_redis_client, sample_messages): - """Test getitem method using Redis LINDEX.""" - # Mock LINDEX to return specific messages - serialized_msg0 = redis_store._serialize_message(sample_messages[0]) - serialized_msg1 = redis_store._serialize_message(sample_messages[1]) - - def mock_lindex(key, index): - if index == 0: - return serialized_msg0 - if index == -1 or index == 1: - return serialized_msg1 - return None - - mock_redis_client.lindex = AsyncMock(side_effect=mock_lindex) - - # Test positive index - message = await redis_store.getitem(0) - assert message.text == "Hello" - - # Test negative index - message = await redis_store.getitem(-1) - assert message.text == "Hi there!" 
- - async def test_getitem_index_error(self, redis_store, mock_redis_client): - """Test getitem raises IndexError for invalid index.""" - mock_redis_client.lindex = AsyncMock(return_value=None) - - with pytest.raises(IndexError): - await redis_store.getitem(0) - - async def test_setitem(self, redis_store, mock_redis_client, sample_messages): - """Test setitem method using Redis LSET.""" - mock_redis_client.llen.return_value = 2 - mock_redis_client.lset = AsyncMock() - - new_message = Message(role="user", text="Updated message") - await redis_store.setitem(0, new_message) - - mock_redis_client.lset.assert_called_once() - call_args = mock_redis_client.lset.call_args - assert call_args[0][0] == "chat_messages:test_thread_123" - assert call_args[0][1] == 0 - - async def test_setitem_index_error(self, redis_store, mock_redis_client): - """Test setitem raises IndexError for invalid index.""" - mock_redis_client.llen.return_value = 0 - - new_message = Message(role="user", text="Test") - with pytest.raises(IndexError): - await redis_store.setitem(0, new_message) - - async def test_append(self, redis_store, mock_redis_client): - """Test append method delegates to add_messages.""" - message = Message(role="user", text="Appended message") - await redis_store.append(message) - - # Should call pipeline operations via add_messages - mock_redis_client.pipeline.assert_called_with(transaction=True) - - # Verify the message was added via pipeline - pipeline_mock = mock_redis_client.pipeline.return_value.__aenter__.return_value - pipeline_mock.rpush.assert_called() - pipeline_mock.execute.assert_called() - - async def test_count(self, redis_store, mock_redis_client): - """Test count method.""" - mock_redis_client.llen.return_value = 5 - - count = await redis_store.count() - - assert count == 5 - mock_redis_client.llen.assert_called_with("chat_messages:test_thread_123") - - async def test_len_method(self, redis_store, mock_redis_client): - """Test async __len__ method.""" - 
mock_redis_client.llen.return_value = 3 - - length = await redis_store.__len__() - - assert length == 3 - mock_redis_client.llen.assert_called_with("chat_messages:test_thread_123") - - def test_bool_method(self, redis_store): - """Test __bool__ method always returns True.""" - # Store should always be truthy - assert bool(redis_store) is True - assert redis_store.__bool__() is True - - # Should work in if statements (this is what Agent Framework uses) - if redis_store: - assert True # Should reach this - else: - raise AssertionError("Store should be truthy") - - async def test_index_found(self, redis_store, mock_redis_client, sample_messages): - """Test index method when message is found using Redis LINDEX.""" - mock_redis_client.llen.return_value = 2 - - # Mock LINDEX to return messages at each position - serialized_msg0 = redis_store._serialize_message(sample_messages[0]) - serialized_msg1 = redis_store._serialize_message(sample_messages[1]) - - def mock_lindex(key, index): - if index == 0: - return serialized_msg0 - if index == 1: - return serialized_msg1 - return None - - mock_redis_client.lindex = AsyncMock(side_effect=mock_lindex) - - index = await redis_store.index(sample_messages[1]) - assert index == 1 - - # Should have called lindex twice (index 0, then index 1) - assert mock_redis_client.lindex.call_count == 2 - - async def test_index_not_found(self, redis_store, mock_redis_client, sample_messages): - """Test index method when message is not found.""" - mock_redis_client.llen.return_value = 1 - mock_redis_client.lindex = AsyncMock(return_value="different_message") - - with pytest.raises(ValueError, match="Message not found in store"): - await redis_store.index(sample_messages[0]) - - async def test_remove(self, redis_store, mock_redis_client, sample_messages): - """Test remove method using Redis LREM.""" - mock_redis_client.lrem = AsyncMock(return_value=1) # 1 element removed - - await redis_store.remove(sample_messages[0]) - - # Should use LREM to 
remove the message - expected_serialized = redis_store._serialize_message(sample_messages[0]) - mock_redis_client.lrem.assert_called_once_with("chat_messages:test_thread_123", 1, expected_serialized) - - async def test_remove_not_found(self, redis_store, mock_redis_client, sample_messages): - """Test remove method when message is not found.""" - mock_redis_client.lrem = AsyncMock(return_value=0) # 0 elements removed - - with pytest.raises(ValueError, match="Message not found in store"): - await redis_store.remove(sample_messages[0]) - - async def test_extend(self, redis_store, mock_redis_client, sample_messages): - """Test extend method delegates to add_messages.""" - await redis_store.extend(sample_messages[:2]) - - # Should call pipeline operations via add_messages - mock_redis_client.pipeline.assert_called_with(transaction=True) - - # Verify rpush was called for each message - pipeline_mock = mock_redis_client.pipeline.return_value.__aenter__.return_value - assert pipeline_mock.rpush.call_count >= 2 - - async def test_serialize_with_agent_thread(self, redis_store, sample_messages): - """Test that RedisChatMessageStore can be serialized within an AgentThread. - - This test verifies the fix for issue #1991 where calling thread.serialize() - with a RedisChatMessageStore would fail with "Messages should be a list" error. - """ - from agent_framework import AgentThread - - thread = AgentThread(message_store=redis_store) - await thread.on_new_messages(sample_messages) - - serialized = await thread.serialize() - - assert serialized is not None - assert "chat_message_store_state" in serialized - assert serialized["chat_message_store_state"] is not None diff --git a/python/packages/redis/tests/test_redis_provider.py b/python/packages/redis/tests/test_redis_provider.py deleted file mode 100644 index 8e842b3de7..0000000000 --- a/python/packages/redis/tests/test_redis_provider.py +++ /dev/null @@ -1,425 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. 
- -from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch - -import numpy as np -import pytest -from agent_framework import Message -from agent_framework.exceptions import AgentException, ServiceInitializationError -from redisvl.utils.vectorize import CustomTextVectorizer - -from agent_framework_redis import RedisProvider - -CUSTOM_VECTORIZER = CustomTextVectorizer(embed=lambda x: [1.0, 2.0, 3.0], dtype="float32") - - -@pytest.fixture -def mock_index() -> AsyncMock: - idx = AsyncMock() - idx.create = AsyncMock() - idx.load = AsyncMock() - idx.query = AsyncMock() - idx.exists = AsyncMock(return_value=False) - - async def _paginate_generator(*_args: Any, **_kwargs: Any): - # Default empty generator; override per-test as needed - if False: # pragma: no cover - yield [] - return - - idx.paginate = _paginate_generator - return idx - - -@pytest.fixture -def patch_index_from_dict(mock_index: AsyncMock): - with patch("agent_framework_redis._provider.AsyncSearchIndex") as mock_cls: - mock_cls.from_dict = MagicMock(return_value=mock_index) - - # Mock from_existing to return a mock with matching schema by default - # This prevents schema validation errors in tests that don't specifically test schema validation - async def mock_from_existing(index_name, redis_url): - mock_existing = AsyncMock() - # Return a schema that will match whatever the provider generates - # This is a bit of a hack, but allows existing tests to continue working - mock_existing.schema.to_dict = MagicMock( - side_effect=lambda: mock_cls.from_dict.call_args[0][0] if mock_cls.from_dict.call_args else {} - ) - return mock_existing - - mock_cls.from_existing = AsyncMock(side_effect=mock_from_existing) - - yield mock_cls - - -@pytest.fixture -def patch_queries(): - calls: dict[str, Any] = {"TextQuery": [], "HybridQuery": [], "FilterExpression": []} - - def _mk_query(kind: str): - class _Q: # simple marker object with captured kwargs - def __init__(self, **kwargs): - self.kind = kind - 
self.kwargs = kwargs - - return _Q - - with ( - patch( - "agent_framework_redis._provider.TextQuery", - side_effect=lambda **k: calls["TextQuery"].append(k) or _mk_query("text")(**k), - ) as text_q, - patch( - "agent_framework_redis._provider.HybridQuery", - side_effect=lambda **k: calls["HybridQuery"].append(k) or _mk_query("hybrid")(**k), - ) as hybrid_q, - patch( - "agent_framework_redis._provider.FilterExpression", - side_effect=lambda s: calls["FilterExpression"].append(s) or ("FE", s), - ) as filt, - ): - yield {"calls": calls, "TextQuery": text_q, "HybridQuery": hybrid_q, "FilterExpression": filt} - - -class TestRedisProviderInitialization: - # Verifies the provider can be imported from the package - def test_import(self): - from agent_framework_redis._provider import RedisProvider - - assert RedisProvider is not None - - # Constructing without filters should not raise; filters are enforced at call-time - def test_init_without_filters_ok(self, patch_index_from_dict): # noqa: ARG002 - provider = RedisProvider() - assert provider.user_id is None - assert provider.agent_id is None - assert provider.application_id is None - assert provider.thread_id is None - - # Schema should omit vector field when no vector configuration is provided - def test_schema_without_vector_field(self, patch_index_from_dict): - RedisProvider(user_id="u1") - # Inspect schema passed to from_dict - args, kwargs = patch_index_from_dict.from_dict.call_args - schema = args[0] - assert isinstance(schema, dict) - names = [f["name"] for f in schema["fields"]] - types = [f["type"] for f in schema["fields"]] - assert "content" in names - assert "text" in types - assert "vector" not in types - - -class TestRedisProviderMessages: - @pytest.fixture - def sample_messages(self) -> list[Message]: - return [ - Message(role="user", text="Hello, how are you?"), - Message(role="assistant", text="I'm doing well, thank you!"), - Message(role="system", text="You are a helpful assistant"), - ] - - # Writes 
require at least one scoping filter to avoid unbounded operations - async def test_messages_adding_requires_filters(self, patch_index_from_dict): # noqa: ARG002 - provider = RedisProvider() - with pytest.raises(ServiceInitializationError): - await provider.invoked("thread123", Message(role="user", text="Hello")) - - # Captures the per-operation thread id when provided - async def test_thread_created_sets_per_operation_id(self, patch_index_from_dict): # noqa: ARG002 - provider = RedisProvider(user_id="u1") - await provider.thread_created("t1") - assert provider._per_operation_thread_id == "t1" - - # Enforces single-thread usage when scope_to_per_operation_thread_id is True - async def test_thread_created_conflict_when_scoped(self, patch_index_from_dict): # noqa: ARG002 - provider = RedisProvider(user_id="u1", scope_to_per_operation_thread_id=True) - provider._per_operation_thread_id = "t1" - with pytest.raises(ValueError) as exc: - await provider.thread_created("t2") - assert "only be used with one thread" in str(exc.value) - - # Aggregates all results from the async paginator into a flat list - async def test_search_all_paginates(self, mock_index: AsyncMock, patch_index_from_dict): # noqa: ARG002 - async def gen(_q, page_size: int = 200): # noqa: ARG001, ANN001 - yield [{"id": 1}] - yield [{"id": 2}, {"id": 3}] - - mock_index.paginate = gen - provider = RedisProvider(user_id="u1") - res = await provider.search_all(page_size=2) - assert res == [{"id": 1}, {"id": 2}, {"id": 3}] - - -class TestRedisProviderModelInvoking: - # Reads require at least one scoping filter to avoid unbounded operations - async def test_model_invoking_requires_filters(self, patch_index_from_dict): # noqa: ARG002 - provider = RedisProvider() - with pytest.raises(ServiceInitializationError): - await provider.invoking(Message(role="user", text="Hi")) - - # Ensures text-only search path is used and context is composed from hits - async def test_textquery_path_and_context_contents( - self, 
mock_index: AsyncMock, patch_index_from_dict, patch_queries - ): # noqa: ARG002 - # Arrange: text-only search - mock_index.query = AsyncMock(return_value=[{"content": "A"}, {"content": "B"}]) - provider = RedisProvider(user_id="u1") - - # Act - ctx = await provider.invoking([Message(role="user", text="q1")]) - - # Assert: TextQuery used (not HybridQuery), filter_expression included - assert patch_queries["TextQuery"].call_count == 1 - assert patch_queries["HybridQuery"].call_count == 0 - kwargs = patch_queries["calls"]["TextQuery"][0] - assert kwargs["text"] == "q1" - assert kwargs["text_field_name"] == "content" - assert kwargs["num_results"] == 10 - assert "filter_expression" in kwargs - - # Context contains memories joined after the default prompt - assert ctx.messages is not None and len(ctx.messages) == 1 - text = ctx.messages[0].text - assert text.endswith("A\nB") - - # When no results are returned, Context should have no contents - async def test_model_invoking_empty_results_returns_empty_context( - self, mock_index: AsyncMock, patch_index_from_dict, patch_queries - ): # noqa: ARG002 - mock_index.query = AsyncMock(return_value=[]) - provider = RedisProvider(user_id="u1") - ctx = await provider.invoking([Message(role="user", text="any")]) - assert ctx.messages == [] - - # Ensures hybrid vector-text search is used when a vectorizer and vector field are configured - async def test_hybridquery_path_with_vectorizer(self, mock_index: AsyncMock, patch_index_from_dict, patch_queries): # noqa: ARG002 - mock_index.query = AsyncMock(return_value=[{"content": "Hit"}]) - provider = RedisProvider(user_id="u1", redis_vectorizer=CUSTOM_VECTORIZER, vector_field_name="vec") - - ctx = await provider.invoking([Message(role="user", text="hello")]) - - # Assert: HybridQuery used with vector and vector field - assert patch_queries["HybridQuery"].call_count == 1 - k = patch_queries["calls"]["HybridQuery"][0] - assert k["text"] == "hello" - assert k["vector_field_name"] == "vec" - 
assert k["vector"] == [1.0, 2.0, 3.0] - assert k["dtype"] == "float32" - assert k["num_results"] == 10 - assert "filter_expression" in k - - # Context assembled from returned memories - assert ctx.messages and "Hit" in ctx.messages[0].text - - -class TestRedisProviderContextManager: - # Verifies async context manager returns self for chaining - async def test_async_context_manager_returns_self(self, patch_index_from_dict): # noqa: ARG002 - provider = RedisProvider(user_id="u1") - async with provider as ctx: - assert ctx is provider - - # Exit should be a no-op and not raise - async def test_aexit_noop(self, patch_index_from_dict): # noqa: ARG002 - provider = RedisProvider(user_id="u1") - assert await provider.__aexit__(None, None, None) is None - - -class TestMessagesAddingBehavior: - # Adds messages while injecting partition defaults and preserving allowed roles - async def test_messages_adding_adds_partition_defaults_and_roles( - self, mock_index: AsyncMock, patch_index_from_dict - ): # noqa: ARG002 - provider = RedisProvider( - application_id="app", - agent_id="agent", - user_id="u1", - scope_to_per_operation_thread_id=True, - ) - - msgs = [ - Message(role="user", text="u"), - Message(role="assistant", text="a"), - Message(role="system", text="s"), - ] - - await provider.invoked(msgs) - - # Ensure load invoked with shaped docs containing defaults - assert mock_index.load.await_count == 1 - (loaded_args, _kwargs) = mock_index.load.call_args - docs = loaded_args[0] - assert isinstance(docs, list) and len(docs) == 3 - for d in docs: - assert d["role"] in {"user", "assistant", "system"} - assert d["content"] in {"u", "a", "s"} - assert d["application_id"] == "app" - assert d["agent_id"] == "agent" - assert d["user_id"] == "u1" - - # Skips blank text and disallowed roles (e.g., TOOL) when adding messages - async def test_messages_adding_ignores_blank_and_disallowed_roles( - self, mock_index: AsyncMock, patch_index_from_dict - ): # noqa: ARG002 - provider = 
RedisProvider(user_id="u1", scope_to_per_operation_thread_id=True) - msgs = [ - Message(role="user", text=" "), - Message(role="tool", text="tool output"), - ] - await provider.invoked(msgs) - # No valid messages -> no load - assert mock_index.load.await_count == 0 - - -class TestIndexCreationPublicCalls: - # Ensures index is created only once when drop=True on first public write call - async def test_messages_adding_triggers_index_create_once_when_drop_true( - self, mock_index: AsyncMock, patch_index_from_dict - ): # noqa: ARG002 - provider = RedisProvider(user_id="u1") - await provider.invoked(Message(role="user", text="m1")) - await provider.invoked(Message(role="user", text="m2")) - # create only on first call - assert mock_index.create.await_count == 1 - - # Ensures index is created when drop=False and the index does not exist on first read - async def test_model_invoking_triggers_create_when_drop_false_and_not_exists( - self, mock_index: AsyncMock, patch_index_from_dict - ): # noqa: ARG002 - mock_index.exists = AsyncMock(return_value=False) - provider = RedisProvider(user_id="u1") - mock_index.query = AsyncMock(return_value=[{"content": "C"}]) - await provider.invoking([Message(role="user", text="q")]) - assert mock_index.create.await_count == 1 - - -class TestThreadCreatedAdditional: - # Allows None or same thread id repeatedly; different id raises when scoped - async def test_thread_created_allows_none_and_same_id(self, patch_index_from_dict): # noqa: ARG002 - provider = RedisProvider(user_id="u1", scope_to_per_operation_thread_id=True) - # None is allowed - await provider.thread_created(None) - # Same id is allowed repeatedly - await provider.thread_created("t1") - await provider.thread_created("t1") - # Different id should raise - with pytest.raises(ValueError): - await provider.thread_created("t2") - - -class TestVectorPopulation: - # When vectorizer configured, invoked should embed content and populate the vector field - async def 
test_messages_adding_populates_vector_field_when_vectorizer_present( - self, mock_index: AsyncMock, patch_index_from_dict - ): # noqa: ARG002 - provider = RedisProvider( - user_id="u1", - scope_to_per_operation_thread_id=True, - redis_vectorizer=CUSTOM_VECTORIZER, - vector_field_name="vec", - ) - - await provider.invoked(Message(role="user", text="hello")) - assert mock_index.load.await_count == 1 - (loaded_args, _kwargs) = mock_index.load.call_args - docs = loaded_args[0] - assert isinstance(docs, list) and len(docs) == 1 - vec = docs[0].get("vec") - assert isinstance(vec, (bytes, bytearray)) - assert len(vec) == 3 * np.dtype(np.float32).itemsize - - -class TestRedisProviderSchemaVectors: - # Adds a vector field when vectorizer supplies dims implicitly - def test_schema_with_vector_field_and_dims_inferred(self, patch_index_from_dict): # noqa: ARG002 - RedisProvider(user_id="u1", redis_vectorizer=CUSTOM_VECTORIZER, vector_field_name="vec") - args, _ = patch_index_from_dict.from_dict.call_args - schema = args[0] - names = [f["name"] for f in schema["fields"]] - types = {f["name"]: f["type"] for f in schema["fields"]} - assert "vec" in names - assert types["vec"] == "vector" - - # Raises when redis_vectorizer is not the correct type - def test_init_invalid_vectorizer(self, patch_index_from_dict): # noqa: ARG002 - class DummyVectorizer: - pass - - with pytest.raises(AgentException): - RedisProvider(user_id="u1", redis_vectorizer=DummyVectorizer(), vector_field_name="vec") - - -class TestEnsureIndex: - # Creates index once and marks _index_initialized to prevent duplicate calls - async def test_ensure_index_creates_once(self, mock_index: AsyncMock, patch_index_from_dict): # noqa: ARG002 - # Mock index doesn't exist, so it will be created - mock_index.exists = AsyncMock(return_value=False) - provider = RedisProvider(user_id="u1", overwrite_index=False) - - assert provider._index_initialized is False - await provider._ensure_index() - assert mock_index.create.await_count 
== 1 - assert provider._index_initialized is True - - # Second call should not create again due to _index_initialized flag - await provider._ensure_index() - assert mock_index.create.await_count == 1 - - # Creates index with overwrite=True when overwrite_index=True - async def test_ensure_index_with_overwrite_true(self, mock_index: AsyncMock, patch_index_from_dict): # noqa: ARG002 - mock_index.exists = AsyncMock(return_value=True) - provider = RedisProvider(user_id="u1", overwrite_index=True) - - await provider._ensure_index() - - # Should call create with overwrite=True, drop=False - mock_index.create.assert_called_once_with(overwrite=True, drop=False) - - # Creates index with overwrite=False when index doesn't exist - async def test_ensure_index_create_if_missing(self, mock_index: AsyncMock, patch_index_from_dict): # noqa: ARG002 - mock_index.exists = AsyncMock(return_value=False) - provider = RedisProvider(user_id="u1", overwrite_index=False) - - await provider._ensure_index() - - # Should call create with overwrite=False, drop=False - mock_index.create.assert_called_once_with(overwrite=False, drop=False) - - # Validates schema compatibility when index exists and overwrite=False - async def test_ensure_index_schema_validation_success(self, mock_index: AsyncMock, patch_index_from_dict): # noqa: ARG002 - mock_index.exists = AsyncMock(return_value=True) - provider = RedisProvider(user_id="u1", overwrite_index=False) - - # Mock existing index with matching schema - expected_schema = provider.schema_dict - patch_index_from_dict.from_existing.return_value.schema.to_dict.return_value = expected_schema - - await provider._ensure_index() - - # Should validate schema and proceed to create - patch_index_from_dict.from_existing.assert_called_once_with("context", redis_url="redis://localhost:6379") - mock_index.create.assert_called_once_with(overwrite=False, drop=False) - - # Raises ServiceInitializationError when schemas don't match - async def 
test_ensure_index_schema_validation_failure(self, mock_index: AsyncMock, patch_index_from_dict): # noqa: ARG002 - mock_index.exists = AsyncMock(return_value=True) - provider = RedisProvider(user_id="u1", overwrite_index=False) - - # Override the mock to return a different schema after provider is created - async def mock_from_existing_different(index_name, redis_url): - mock_existing = AsyncMock() - mock_existing.schema.to_dict = MagicMock(return_value={"different": "schema"}) - return mock_existing - - patch_index_from_dict.from_existing = AsyncMock(side_effect=mock_from_existing_different) - - with pytest.raises(ServiceInitializationError) as exc: - await provider._ensure_index() - - assert "incompatible with the current configuration" in str(exc.value) - assert "overwrite_index=True" in str(exc.value) - - # Should not call create when schema validation fails - mock_index.create.assert_not_called() From 6e9d6df04ba6a74d27823b617dc79ba60436044f Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Feb 2026 12:29:29 +0100 Subject: [PATCH 03/28] =?UTF-8?q?refactor:=20update=20all=20sample=20files?= =?UTF-8?q?=20for=20context=20provider=20pipeline=20(AgentThread=E2=86=92A?= =?UTF-8?q?gentSession,=20ContextProvider=E2=86=92BaseContextProvider)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../samples/02-agents/background_responses.py | 12 +- .../chat_client/custom_chat_client.py | 17 +- .../aggregate_context_provider.py | 248 +++++------------- .../azure_ai_with_search_context_agentic.py | 2 +- .../azure_ai_with_search_context_semantic.py | 2 +- .../context_providers/mem0/mem0_basic.py | 16 +- .../context_providers/mem0/mem0_oss.py | 16 +- .../context_providers/mem0/mem0_threads.py | 62 ++--- .../redis/azure_redis_conversation.py | 35 ++- .../context_providers/redis/redis_basics.py | 42 +-- .../redis/redis_conversation.py | 22 +- .../context_providers/redis/redis_threads.py | 64 +++-- 
.../simple_context_provider.py | 55 ++-- .../custom_chat_message_store_thread.py | 96 ++++--- .../redis_chat_message_store_thread.py | 200 +++++--------- .../conversations/suspend_resume_thread.py | 67 ++--- .../middleware/thread_behavior_middleware.py | 55 ++-- .../observability/agent_observability.py | 4 +- .../agent_with_foundry_tracing.py | 4 +- .../azure_ai_agent_observability.py | 4 +- .../anthropic_claude_with_session.py | 24 +- .../azure_ai_with_existing_conversation.py | 16 +- .../azure_ai/azure_ai_with_hosted_mcp.py | 28 +- .../azure_ai/azure_ai_with_thread.py | 80 +++--- .../azure_ai_with_existing_thread.py | 18 +- .../azure_ai_with_hosted_mcp.py | 16 +- .../azure_ai_with_multiple_tools.py | 16 +- .../azure_ai_with_openapi_tools.py | 10 +- .../azure_ai_agent/azure_ai_with_thread.py | 88 +++---- .../azure_assistants_with_thread.py | 90 +++---- .../azure_chat_client_with_thread.py | 88 +++---- .../azure_responses_client_with_hosted_mcp.py | 64 ++--- .../azure_responses_client_with_thread.py | 82 +++--- .../providers/custom/custom_agent.py | 45 ++-- .../github_copilot_with_session.py | 24 +- .../openai/openai_assistants_with_thread.py | 86 +++--- .../openai/openai_chat_client_with_thread.py | 88 +++---- ...openai_responses_client_with_hosted_mcp.py | 64 ++--- .../openai_responses_client_with_thread.py | 80 +++--- .../function_tool_recover_from_failures.py | 31 ++- ...function_tool_with_approval_and_threads.py | 28 +- .../function_tool_with_max_exceptions.py | 31 ++- .../function_tool_with_max_invocations.py | 31 ++- .../function_tool_with_thread_injection.py | 33 ++- .../azure_ai_agents_with_shared_thread.py | 23 +- .../agents/workflow_as_agent_with_thread.py | 70 +++-- .../workflow_as_agent_checkpoint.py | 23 +- .../function_app.py | 12 +- .../function_app.py | 8 +- .../function_app.py | 8 +- .../function_app.py | 6 +- .../durabletask/01_single_agent/client.py | 8 +- .../durabletask/02_multi_agent/client.py | 12 +- .../03_single_agent_streaming/client.py | 
14 +- .../worker.py | 10 +- .../worker.py | 12 +- .../worker.py | 10 +- .../agent_with_text_search_rag/main.py | 34 ++- .../03_assistant_agent_thread_and_stream.py | 18 +- ...03_azure_ai_agent_threads_and_followups.py | 12 +- .../02_chat_completion_with_tool.py | 4 +- .../03_chat_completion_thread_and_stream.py | 8 +- .../01_basic_openai_assistant.py | 2 +- 63 files changed, 1143 insertions(+), 1335 deletions(-) diff --git a/python/samples/02-agents/background_responses.py b/python/samples/02-agents/background_responses.py index 674c2439eb..9c04a59f27 100644 --- a/python/samples/02-agents/background_responses.py +++ b/python/samples/02-agents/background_responses.py @@ -33,12 +33,12 @@ async def non_streaming_polling() -> None: """Demonstrate non-streaming background run with polling.""" print("=== Non-Streaming Polling ===\n") - thread = agent.get_new_thread() + session = agent.create_session() # 2. Start a background run — returns immediately. response = await agent.run( messages="Briefly explain the theory of relativity in two sentences.", - thread=thread, + session=session, options={"background": True}, ) @@ -50,7 +50,7 @@ async def non_streaming_polling() -> None: poll_count += 1 await asyncio.sleep(2) response = await agent.run( - thread=thread, + session=session, options={"continuation_token": response.continuation_token}, ) print(f" Poll {poll_count}: continuation_token={'set' if response.continuation_token else 'None'}") @@ -63,14 +63,14 @@ async def streaming_with_resumption() -> None: """Demonstrate streaming background run with simulated interruption and resumption.""" print("=== Streaming with Resumption ===\n") - thread = agent.get_new_thread() + session = agent.create_session() # 2. Start a streaming background run. 
last_token = None stream = agent.run( messages="Briefly list three benefits of exercise.", stream=True, - thread=thread, + session=session, options={"background": True}, ) @@ -91,7 +91,7 @@ async def streaming_with_resumption() -> None: print("Resumed stream:") stream = agent.run( stream=True, - thread=thread, + session=session, options={"continuation_token": last_token}, ) async for update in stream: diff --git a/python/samples/02-agents/chat_client/custom_chat_client.py b/python/samples/02-agents/chat_client/custom_chat_client.py index 69228b68ab..b6c69bd0ac 100644 --- a/python/samples/02-agents/chat_client/custom_chat_client.py +++ b/python/samples/02-agents/chat_client/custom_chat_client.py @@ -160,10 +160,10 @@ async def main() -> None: print(chunk.text, end="", flush=True) print() - # Example: Using with threads and conversation history - print("\n--- Using Custom Chat Client with Thread ---") + # Example: Using with sessions and conversation history + print("\n--- Using Custom Chat Client with Session ---") - thread = echo_agent.get_new_thread() + session = echo_agent.create_session() # Multiple messages in conversation messages = [ @@ -173,16 +173,17 @@ async def main() -> None: ] for msg in messages: - result = await echo_agent.run(msg, thread=thread) + result = await echo_agent.run(msg, session=session) print(f"User: {msg}") print(f"Agent: {result.messages[0].text}\n") # Check conversation history - if thread.message_store: - thread_messages = await thread.message_store.list_messages() - print(f"Thread contains {len(thread_messages)} messages") + memory_state = session.state.get("memory", {}) + session_messages = memory_state.get("messages", []) + if session_messages: + print(f"Session contains {len(session_messages)} messages") else: - print("Thread has no message store configured") + print("Session has no messages stored") if __name__ == "__main__": diff --git a/python/samples/02-agents/context_providers/aggregate_context_provider.py 
b/python/samples/02-agents/context_providers/aggregate_context_provider.py index af3780cfc1..4e5cfb72aa 100644 --- a/python/samples/02-agents/context_providers/aggregate_context_provider.py +++ b/python/samples/02-agents/context_providers/aggregate_context_provider.py @@ -1,216 +1,97 @@ # Copyright (c) Microsoft. All rights reserved. """ -This sample demonstrates how to use an AggregateContextProvider to combine multiple context providers. +This sample demonstrates how to use multiple context providers with an agent. -The AggregateContextProvider is a convenience class that allows you to aggregate multiple -ContextProviders into a single provider. It delegates events to all providers and combines -their context before returning. +Context providers can be passed as a list to the agent's context_providers parameter. +Each provider is called in order during the agent's lifecycle, and their context +is combined automatically. -You can use this implementation as-is, or implement your own aggregation logic. +You can use built-in providers or implement your own by extending BaseContextProvider. 
""" import asyncio -import sys -from collections.abc import MutableSequence, Sequence -from contextlib import AsyncExitStack -from types import TracebackType -from typing import TYPE_CHECKING, Any, cast +from typing import Any -from agent_framework import Agent, Context, ContextProvider, Message +from agent_framework import Agent, AgentSession, BaseContextProvider, Message, SessionContext from agent_framework.azure import AzureAIClient from azure.identity.aio import AzureCliCredential -if TYPE_CHECKING: - from agent_framework import FunctionTool - -if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover -else: - from typing_extensions import override # type: ignore[import] # pragma: no cover -if sys.version_info >= (3, 11): - from typing import Self # pragma: no cover -else: - from typing_extensions import Self # pragma: no cover - - -# region AggregateContextProvider - - -class AggregateContextProvider(ContextProvider): - """A ContextProvider that contains multiple context providers. - - It delegates events to multiple context providers and aggregates responses from those - events before returning. This allows you to combine multiple context providers into a - single provider. - - Examples: - .. 
code-block:: python - - from agent_framework import Agent - - # Create multiple context providers - provider1 = CustomContextProvider1() - provider2 = CustomContextProvider2() - provider3 = CustomContextProvider3() - - # Combine them using AggregateContextProvider - aggregate = AggregateContextProvider([provider1, provider2, provider3]) - - # Pass the aggregate to the agent - agent = Agent(client=client, name="assistant", context_provider=aggregate) - - # You can also add more providers later - provider4 = CustomContextProvider4() - aggregate.add(provider4) - """ - - def __init__(self, context_providers: ContextProvider | Sequence[ContextProvider] | None = None) -> None: - """Initialize the AggregateContextProvider with context providers. - - Args: - context_providers: The context provider(s) to add. - """ - if isinstance(context_providers, ContextProvider): - self.providers = [context_providers] - else: - self.providers = cast(list[ContextProvider], context_providers) or [] - self._exit_stack: AsyncExitStack | None = None - - def add(self, context_provider: ContextProvider) -> None: - """Add a new context provider. - - Args: - context_provider: The context provider to add. 
- """ - self.providers.append(context_provider) - - @override - async def thread_created(self, thread_id: str | None = None) -> None: - await asyncio.gather(*[x.thread_created(thread_id) for x in self.providers]) - - @override - async def invoking(self, messages: Message | MutableSequence[Message], **kwargs: Any) -> Context: - contexts = await asyncio.gather(*[provider.invoking(messages, **kwargs) for provider in self.providers]) - instructions: str = "" - return_messages: list[Message] = [] - tools: list["FunctionTool"] = [] - for ctx in contexts: - if ctx.instructions: - instructions += ctx.instructions - if ctx.messages: - return_messages.extend(ctx.messages) - if ctx.tools: - tools.extend(ctx.tools) - return Context(instructions=instructions, messages=return_messages, tools=tools) - - @override - async def invoked( - self, - request_messages: Message | Sequence[Message], - response_messages: Message | Sequence[Message] | None = None, - invoke_exception: Exception | None = None, - **kwargs: Any, - ) -> None: - await asyncio.gather(*[ - x.invoked( - request_messages=request_messages, - response_messages=response_messages, - invoke_exception=invoke_exception, - **kwargs, - ) - for x in self.providers - ]) - - @override - async def __aenter__(self) -> "Self": - """Enter the async context manager and set up all providers. - - Returns: - The AggregateContextProvider instance for chaining. - """ - self._exit_stack = AsyncExitStack() - await self._exit_stack.__aenter__() - - # Enter all context providers - for provider in self.providers: - await self._exit_stack.enter_async_context(provider) - - return self - - @override - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - """Exit the async context manager and clean up all providers. - - Args: - exc_type: The exception type if an exception occurred, None otherwise. 
- exc_val: The exception value if an exception occurred, None otherwise. - exc_tb: The exception traceback if an exception occurred, None otherwise. - """ - if self._exit_stack is not None: - await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) - self._exit_stack = None - - -# endregion - # region Example Context Providers -class TimeContextProvider(ContextProvider): +class TimeContextProvider(BaseContextProvider): """A simple context provider that adds time-related instructions.""" - @override - async def invoking(self, messages: Message | MutableSequence[Message], **kwargs: Any) -> Context: + def __init__(self): + super().__init__("time") + + async def before_run( + self, + *, + agent: Any, + session: AgentSession | None, + context: SessionContext, + state: dict[str, Any], + ) -> None: from datetime import datetime current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - return Context(instructions=f"The current date and time is: {current_time}. ") + context.extend_instructions(self.source_id, f"The current date and time is: {current_time}. ") -class PersonaContextProvider(ContextProvider): +class PersonaContextProvider(BaseContextProvider): """A context provider that adds a persona to the agent.""" def __init__(self, persona: str): + super().__init__("persona") self.persona = persona - @override - async def invoking(self, messages: Message | MutableSequence[Message], **kwargs: Any) -> Context: - return Context(instructions=f"Your persona: {self.persona}. ") + async def before_run( + self, + *, + agent: Any, + session: AgentSession | None, + context: SessionContext, + state: dict[str, Any], + ) -> None: + context.extend_instructions(self.source_id, f"Your persona: {self.persona}. 
") -class PreferencesContextProvider(ContextProvider): +class PreferencesContextProvider(BaseContextProvider): """A context provider that adds user preferences.""" def __init__(self): + super().__init__("preferences") self.preferences: dict[str, str] = {} - @override - async def invoking(self, messages: Message | MutableSequence[Message], **kwargs: Any) -> Context: + async def before_run( + self, + *, + agent: Any, + session: AgentSession | None, + context: SessionContext, + state: dict[str, Any], + ) -> None: if not self.preferences: - return Context() + return prefs_str = ", ".join(f"{k}: {v}" for k, v in self.preferences.items()) - return Context(instructions=f"User preferences: {prefs_str}. ") + context.extend_instructions(self.source_id, f"User preferences: {prefs_str}. ") - @override - async def invoked( + async def after_run( self, - request_messages: Message | Sequence[Message], - response_messages: Message | Sequence[Message] | None = None, - invoke_exception: Exception | None = None, - **kwargs: Any, + *, + agent: Any, + session: AgentSession | None, + context: SessionContext, + state: dict[str, Any], ) -> None: # Simple example: extract and store preferences from user messages # In a real implementation, you might use structured extraction - msgs = [request_messages] if isinstance(request_messages, Message) else list(request_messages) + request_messages = context.get_messages() - for msg in msgs: + for msg in request_messages: content = msg.text if hasattr(msg, "text") else "" # Very simple extraction - in production, use LLM-based extraction if isinstance(content, str) and "prefer" in content.lower() and ":" in content: @@ -228,7 +109,7 @@ async def invoked( async def main(): - """Demonstrate using AggregateContextProvider to combine multiple providers.""" + """Demonstrate using multiple context providers with an agent.""" async with AzureCliCredential() as credential: client = AzureAIClient(credential=credential) @@ -237,35 +118,32 @@ async def main(): 
persona_provider = PersonaContextProvider("You are a helpful and friendly AI assistant named Max.") preferences_provider = PreferencesContextProvider() - # Combine them using AggregateContextProvider - aggregate_provider = AggregateContextProvider([ - time_provider, - persona_provider, - preferences_provider, - ]) - - # Create the agent with the aggregate provider + # Create the agent with multiple context providers async with Agent( client=client, instructions="You are a helpful assistant.", - context_provider=aggregate_provider, + context_providers=[ + time_provider, + persona_provider, + preferences_provider, + ], ) as agent: - # Create a new thread for the conversation - thread = agent.get_new_thread() + # Create a new session for the conversation + session = agent.create_session() # First message - the agent should include time and persona context print("User: Hello! Who are you?") - result = await agent.run("Hello! Who are you?", thread=thread) + result = await agent.run("Hello! Who are you?", session=session) print(f"Agent: {result}\n") # Set a preference print("User: I prefer language: formal English") - result = await agent.run("I prefer language: formal English", thread=thread) + result = await agent.run("I prefer language: formal English", session=session) print(f"Agent: {result}\n") # Ask something - the agent should now include the preference print("User: Can you tell me a fun fact?") - result = await agent.run("Can you tell me a fun fact?", thread=thread) + result = await agent.run("Can you tell me a fun fact?", session=session) print(f"Agent: {result}\n") # Show what the aggregate provider is tracking diff --git a/python/samples/02-agents/context_providers/azure_ai_search/azure_ai_with_search_context_agentic.py b/python/samples/02-agents/context_providers/azure_ai_search/azure_ai_with_search_context_agentic.py index 7b68265885..5a4503f920 100644 --- a/python/samples/02-agents/context_providers/azure_ai_search/azure_ai_with_search_context_agentic.py 
+++ b/python/samples/02-agents/context_providers/azure_ai_search/azure_ai_with_search_context_agentic.py @@ -120,7 +120,7 @@ async def main() -> None: "Use the provided context from the knowledge base to answer complex " "questions that may require synthesizing information from multiple sources." ), - context_provider=search_provider, + context_providers=[search_provider], ) as agent, ): print("=== Azure AI Agent with Search Context (Agentic Mode) ===\n") diff --git a/python/samples/02-agents/context_providers/azure_ai_search/azure_ai_with_search_context_semantic.py b/python/samples/02-agents/context_providers/azure_ai_search/azure_ai_with_search_context_semantic.py index 04e26e535e..8309d5197c 100644 --- a/python/samples/02-agents/context_providers/azure_ai_search/azure_ai_with_search_context_semantic.py +++ b/python/samples/02-agents/context_providers/azure_ai_search/azure_ai_with_search_context_semantic.py @@ -76,7 +76,7 @@ async def main() -> None: "You are a helpful assistant. Use the provided context from the " "knowledge base to answer questions accurately." 
), - context_provider=search_provider, + context_providers=[search_provider], ) as agent, ): print("=== Azure AI Agent with Search Context (Semantic Mode) ===\n") diff --git a/python/samples/02-agents/context_providers/mem0/mem0_basic.py b/python/samples/02-agents/context_providers/mem0/mem0_basic.py index 1252ee2b49..f7a3a7f91f 100644 --- a/python/samples/02-agents/context_providers/mem0/mem0_basic.py +++ b/python/samples/02-agents/context_providers/mem0/mem0_basic.py @@ -5,7 +5,7 @@ from agent_framework import tool from agent_framework.azure import AzureAIAgentClient -from agent_framework.mem0 import Mem0Provider +from agent_framework.mem0 import Mem0ContextProvider from azure.identity.aio import AzureCliCredential @@ -39,7 +39,7 @@ async def main() -> None: name="FriendlyAssistant", instructions="You are a friendly assistant.", tools=retrieve_company_report, - context_provider=Mem0Provider(user_id=user_id), + context_providers=[Mem0ContextProvider(user_id=user_id)], ) as agent, ): # First ask the agent to retrieve a company report with no previous context. @@ -64,17 +64,17 @@ async def main() -> None: print("Waiting for memories to be processed...") await asyncio.sleep(12) # Empirically determined delay for Mem0 indexing - print("\nRequest within a new thread:") - # Create a new thread for the agent. - # The new thread has no context of the previous conversation. - thread = agent.get_new_thread() + print("\nRequest within a new session:") + # Create a new session for the agent. + # The new session has no context of the previous conversation. + session = agent.create_session() - # Since we have the mem0 component in the thread, the agent should be able to + # Since we have the mem0 component in the session, the agent should be able to # retrieve the company report without asking for clarification, as it will # be able to remember the user preferences from Mem0 component. 
query = "Please retrieve my company report" print(f"User: {query}") - result = await agent.run(query, thread=thread) + result = await agent.run(query, session=session) print(f"Agent: {result}\n") diff --git a/python/samples/02-agents/context_providers/mem0/mem0_oss.py b/python/samples/02-agents/context_providers/mem0/mem0_oss.py index b22d76a972..2178bbfe58 100644 --- a/python/samples/02-agents/context_providers/mem0/mem0_oss.py +++ b/python/samples/02-agents/context_providers/mem0/mem0_oss.py @@ -5,7 +5,7 @@ from agent_framework import tool from agent_framework.azure import AzureAIAgentClient -from agent_framework.mem0 import Mem0Provider +from agent_framework.mem0 import Mem0ContextProvider from azure.identity.aio import AzureCliCredential from mem0 import AsyncMemory @@ -42,7 +42,7 @@ async def main() -> None: name="FriendlyAssistant", instructions="You are a friendly assistant.", tools=retrieve_company_report, - context_provider=Mem0Provider(user_id=user_id, mem0_client=local_mem0_client), + context_providers=[Mem0ContextProvider(user_id=user_id, mem0_client=local_mem0_client)], ) as agent, ): # First ask the agent to retrieve a company report with no previous context. @@ -60,18 +60,18 @@ async def main() -> None: result = await agent.run(query) print(f"Agent: {result}\n") - print("\nRequest within a new thread:") + print("\nRequest within a new session:") - # Create a new thread for the agent. - # The new thread has no context of the previous conversation. - thread = agent.get_new_thread() + # Create a new session for the agent. + # The new session has no context of the previous conversation. + session = agent.create_session() - # Since we have the mem0 component in the thread, the agent should be able to + # Since we have the mem0 component in the session, the agent should be able to # retrieve the company report without asking for clarification, as it will # be able to remember the user preferences from Mem0 component. 
query = "Please retrieve my company report" print(f"User: {query}") - result = await agent.run(query, thread=thread) + result = await agent.run(query, session=session) print(f"Agent: {result}\n") diff --git a/python/samples/02-agents/context_providers/mem0/mem0_threads.py b/python/samples/02-agents/context_providers/mem0/mem0_threads.py index 2e564d708c..dd657b4e1d 100644 --- a/python/samples/02-agents/context_providers/mem0/mem0_threads.py +++ b/python/samples/02-agents/context_providers/mem0/mem0_threads.py @@ -5,7 +5,7 @@ from agent_framework import tool from agent_framework.azure import AzureAIAgentClient -from agent_framework.mem0 import Mem0Provider +from agent_framework.mem0 import Mem0ContextProvider from azure.identity.aio import AzureCliCredential @@ -34,11 +34,11 @@ async def example_global_thread_scope() -> None: name="GlobalMemoryAssistant", instructions="You are an assistant that remembers user preferences across conversations.", tools=get_user_preferences, - context_provider=Mem0Provider( + context_providers=[Mem0ContextProvider( user_id=user_id, thread_id=global_thread_id, - scope_to_per_operation_thread_id=False, # Share memories across all threads - ), + scope_to_per_operation_thread_id=False, # Share memories across all sessions + )], ) as global_agent, ): # Store some preferences in the global scope @@ -47,19 +47,19 @@ async def example_global_thread_scope() -> None: result = await global_agent.run(query) print(f"Agent: {result}\n") - # Create a new thread - but memories should still be accessible due to global scope - new_thread = global_agent.get_new_thread() + # Create a new session - but memories should still be accessible due to global scope + new_session = global_agent.create_session() query = "What do you know about my preferences?" 
- print(f"User (new thread): {query}") - result = await global_agent.run(query, thread=new_thread) + print(f"User (new session): {query}") + result = await global_agent.run(query, session=new_session) print(f"Agent: {result}\n") async def example_per_operation_thread_scope() -> None: - """Example 2: Per-operation thread scope (memories isolated per thread). + """Example 2: Per-operation thread scope (memories isolated per session). - Note: When scope_to_per_operation_thread_id=True, the provider is bound to a single thread - throughout its lifetime. Use the same thread object for all operations with that provider. + Note: When scope_to_per_operation_thread_id=True, the provider is bound to a single session + throughout its lifetime. Use the same session object for all operations with that provider. """ print("2. Per-Operation Thread Scope Example:") print("-" * 40) @@ -72,37 +72,37 @@ async def example_per_operation_thread_scope() -> None: name="ScopedMemoryAssistant", instructions="You are an assistant with thread-scoped memory.", tools=get_user_preferences, - context_provider=Mem0Provider( + context_providers=[Mem0ContextProvider( user_id=user_id, - scope_to_per_operation_thread_id=True, # Isolate memories per thread - ), + scope_to_per_operation_thread_id=True, # Isolate memories per session + )], ) as scoped_agent, ): - # Create a specific thread for this scoped provider - dedicated_thread = scoped_agent.get_new_thread() + # Create a specific session for this scoped provider + dedicated_session = scoped_agent.create_session() - # Store some information in the dedicated thread + # Store some information in the dedicated session query = "Remember that for this conversation, I'm working on a Python project about data analysis." 
- print(f"User (dedicated thread): {query}") - result = await scoped_agent.run(query, thread=dedicated_thread) + print(f"User (dedicated session): {query}") + result = await scoped_agent.run(query, session=dedicated_session) print(f"Agent: {result}\n") - # Test memory retrieval in the same dedicated thread + # Test memory retrieval in the same dedicated session query = "What project am I working on?" - print(f"User (same dedicated thread): {query}") - result = await scoped_agent.run(query, thread=dedicated_thread) + print(f"User (same dedicated session): {query}") + result = await scoped_agent.run(query, session=dedicated_session) print(f"Agent: {result}\n") - # Store more information in the same thread + # Store more information in the same session query = "Also remember that I prefer using pandas and matplotlib for this project." - print(f"User (same dedicated thread): {query}") - result = await scoped_agent.run(query, thread=dedicated_thread) + print(f"User (same dedicated session): {query}") + result = await scoped_agent.run(query, session=dedicated_session) print(f"Agent: {result}\n") # Test comprehensive memory retrieval query = "What do you know about my current project and preferences?" 
- print(f"User (same dedicated thread): {query}") - result = await scoped_agent.run(query, thread=dedicated_thread) + print(f"User (same dedicated session): {query}") + result = await scoped_agent.run(query, session=dedicated_session) print(f"Agent: {result}\n") @@ -119,16 +119,16 @@ async def example_multiple_agents() -> None: AzureAIAgentClient(credential=credential).as_agent( name="PersonalAssistant", instructions="You are a personal assistant that helps with personal tasks.", - context_provider=Mem0Provider( + context_providers=[Mem0ContextProvider( agent_id=agent_id_1, - ), + )], ) as personal_agent, AzureAIAgentClient(credential=credential).as_agent( name="WorkAssistant", instructions="You are a work assistant that helps with professional tasks.", - context_provider=Mem0Provider( + context_providers=[Mem0ContextProvider( agent_id=agent_id_2, - ), + )], ) as work_agent, ): # Store personal information diff --git a/python/samples/02-agents/context_providers/redis/azure_redis_conversation.py b/python/samples/02-agents/context_providers/redis/azure_redis_conversation.py index 5c300abcbf..ce569be8cb 100644 --- a/python/samples/02-agents/context_providers/redis/azure_redis_conversation.py +++ b/python/samples/02-agents/context_providers/redis/azure_redis_conversation.py @@ -1,9 +1,9 @@ # Copyright (c) Microsoft. All rights reserved. -"""Azure Managed Redis Chat Message Store with Azure AD Authentication +"""Azure Managed Redis History Provider with Azure AD Authentication This example demonstrates how to use Azure Managed Redis with Azure AD authentication -to persist conversational details using RedisChatMessageStore. +to persist conversational details using RedisHistoryProvider. 
Requirements: - Azure Managed Redis instance with Azure AD authentication enabled @@ -22,7 +22,7 @@ import os from agent_framework.openai import OpenAIChatClient -from agent_framework.redis import RedisChatMessageStore +from agent_framework.redis import RedisHistoryProvider from azure.identity.aio import AzureCliCredential from redis.credentials import CredentialProvider @@ -60,28 +60,27 @@ async def main() -> None: azure_credential = AzureCliCredential() credential_provider = AzureCredentialProvider(azure_credential, user_object_id) - thread_id = "azure_test_thread" - - # Factory for creating Azure Redis chat message store - def chat_message_store_factory(): - return RedisChatMessageStore( - credential_provider=credential_provider, - host=redis_host, - port=10000, - ssl=True, - thread_id=thread_id, - key_prefix="chat_messages", - max_messages=100, - ) + session_id = "azure_test_session" + + # Create Azure Redis history provider + history_provider = RedisHistoryProvider( + credential_provider=credential_provider, + host=redis_host, + port=10000, + ssl=True, + thread_id=session_id, + key_prefix="chat_messages", + max_messages=100, + ) # Create chat client client = OpenAIChatClient() - # Create agent with Azure Redis store + # Create agent with Azure Redis history provider agent = client.as_agent( name="AzureRedisAssistant", instructions="You are a helpful assistant.", - chat_message_store_factory=chat_message_store_factory, + context_providers=[history_provider], ) # Conversation diff --git a/python/samples/02-agents/context_providers/redis/redis_basics.py b/python/samples/02-agents/context_providers/redis/redis_basics.py index 5dfcbec850..ba038096db 100644 --- a/python/samples/02-agents/context_providers/redis/redis_basics.py +++ b/python/samples/02-agents/context_providers/redis/redis_basics.py @@ -32,7 +32,7 @@ from agent_framework import Message, tool from agent_framework.openai import OpenAIChatClient -from agent_framework_redis._provider import RedisProvider 
+from agent_framework_redis._context_provider import RedisContextProvider from redisvl.extensions.cache.embeddings import EmbeddingsCache from redisvl.utils.vectorize import OpenAITextVectorizer @@ -104,7 +104,7 @@ async def main() -> None: # Recommend default for OPENAI_CHAT_MODEL_ID is gpt-4o-mini # We attach an embedding vectorizer so the provider can perform hybrid (text + vector) - # retrieval. If you prefer text-only retrieval, instantiate RedisProvider without the + # retrieval. If you prefer text-only retrieval, instantiate RedisContextProvider without the # 'vectorizer' and vector_* parameters. vectorizer = OpenAITextVectorizer( model="text-embedding-ada-002", @@ -114,7 +114,7 @@ async def main() -> None: # The provider manages persistence and retrieval. application_id/agent_id/user_id # scope data for multi-tenant separation; thread_id (set later) narrows to a # specific conversation. - provider = RedisProvider( + provider = RedisContextProvider( redis_url="redis://localhost:6379", index_name="redis_basics", application_id="matrix_of_kermits", @@ -133,21 +133,27 @@ async def main() -> None: Message("system", ["runA CONVO: System Message"]), ] - # Declare/start a conversation/thread and write messages under 'runA'. - # Threads are logical boundaries used by the provider to group and retrieve - # conversation-specific context. - await provider.thread_created(thread_id="runA") - await provider.invoked(request_messages=messages) + # Use the provider's before_run/after_run API to store and retrieve messages. + # In practice, the agent handles this automatically; this shows the low-level API. + from agent_framework import AgentSession, SessionContext - # Retrieve relevant memories for a hypothetical model call. The provider uses - # the current request messages as the retrieval query and returns context to - # be injected into the model's instructions. 
- ctx = await provider.invoking([Message("system", ["B: Assistant Message"])]) + session = AgentSession(session_id="runA") + context = SessionContext() + context.extend_messages("input", messages) + state = session.state + + # Store messages via after_run + await provider.after_run(agent=None, session=session, context=context, state=state) + + # Retrieve relevant memories via before_run + query_context = SessionContext() + query_context.extend_messages("input", [Message("system", ["B: Assistant Message"])]) + await provider.before_run(agent=None, session=session, context=query_context, state=state) # Inspect retrieved memories that would be injected into instructions # (Debug-only output so you can verify retrieval works as expected.) - print("Model Invoking Result:") - print(ctx) + print("Before Run Result:") + print(query_context) # Drop / delete the provider index in Redis await provider.redis_index.delete() @@ -163,7 +169,7 @@ async def main() -> None: cache=EmbeddingsCache(name="openai_embeddings_cache", redis_url="redis://localhost:6379"), ) # Recreate a clean index so the next scenario starts fresh - provider = RedisProvider( + provider = RedisContextProvider( redis_url="redis://localhost:6379", index_name="redis_basics_2", prefix="context_2", @@ -187,7 +193,7 @@ async def main() -> None: "Before answering, always check for stored context" ), tools=[], - context_provider=provider, + context_providers=[provider], ) # Teach a user preference; the agent writes this to the provider's memory @@ -210,7 +216,7 @@ async def main() -> None: print("\n3. Agent + provider + tool: store and recall tool-derived context") print("-" * 40) # Text-only provider (full-text search only). Omits vectorizer and related params. 
- provider = RedisProvider( + provider = RedisContextProvider( redis_url="redis://localhost:6379", index_name="redis_basics_3", prefix="context_3", @@ -229,7 +235,7 @@ async def main() -> None: "Before answering, always check for stored context" ), tools=search_flights, - context_provider=provider, + context_providers=[provider], ) # Invoke the tool; outputs become part of memory/context query = "Are there any flights from new york city (jfk) to la? Give me details" diff --git a/python/samples/02-agents/context_providers/redis/redis_conversation.py b/python/samples/02-agents/context_providers/redis/redis_conversation.py index f202a0cd2c..6de659aba3 100644 --- a/python/samples/02-agents/context_providers/redis/redis_conversation.py +++ b/python/samples/02-agents/context_providers/redis/redis_conversation.py @@ -2,7 +2,7 @@ """Redis Context Provider: Basic usage and agent integration -This example demonstrates how to use the Redis ChatMessageStoreProtocol to persist +This example demonstrates how to use the Redis context provider to persist conversational details. Pass it as a constructor argument to create_agent. 
Requirements: @@ -18,8 +18,7 @@ import os from agent_framework.openai import OpenAIChatClient -from agent_framework_redis._chat_message_store import RedisChatMessageStore -from agent_framework_redis._provider import RedisProvider +from agent_framework_redis._context_provider import RedisContextProvider from redisvl.extensions.cache.embeddings import EmbeddingsCache from redisvl.utils.vectorize import OpenAITextVectorizer @@ -37,9 +36,9 @@ async def main() -> None: cache=EmbeddingsCache(name="openai_embeddings_cache", redis_url="redis://localhost:6379"), ) - thread_id = "test_thread" + session_id = "test_session" - provider = RedisProvider( + provider = RedisContextProvider( redis_url="redis://localhost:6379", index_name="redis_conversation", prefix="redis_conversation", @@ -50,17 +49,9 @@ async def main() -> None: vector_field_name="vector", vector_algorithm="hnsw", vector_distance_metric="cosine", - thread_id=thread_id, + thread_id=session_id, ) - def chat_message_store_factory(): - return RedisChatMessageStore( - redis_url="redis://localhost:6379", - thread_id=thread_id, - key_prefix="chat_messages", - max_messages=100, - ) - # Create chat client for the agent client = OpenAIChatClient(model_id=os.getenv("OPENAI_CHAT_MODEL_ID"), api_key=os.getenv("OPENAI_API_KEY")) # Create agent wired to the Redis context provider. 
The provider automatically @@ -72,8 +63,7 @@ def chat_message_store_factory(): "Before answering, always check for stored context" ), tools=[], - context_provider=provider, - chat_message_store_factory=chat_message_store_factory, + context_providers=[provider], ) # Teach a user preference; the agent writes this to the provider's memory diff --git a/python/samples/02-agents/context_providers/redis/redis_threads.py b/python/samples/02-agents/context_providers/redis/redis_threads.py index 2347281bf5..c11823dfb6 100644 --- a/python/samples/02-agents/context_providers/redis/redis_threads.py +++ b/python/samples/02-agents/context_providers/redis/redis_threads.py @@ -31,7 +31,7 @@ import uuid from agent_framework.openai import OpenAIChatClient -from agent_framework_redis._provider import RedisProvider +from agent_framework_redis._context_provider import RedisContextProvider from redisvl.extensions.cache.embeddings import EmbeddingsCache from redisvl.utils.vectorize import OpenAITextVectorizer @@ -51,16 +51,14 @@ async def example_global_thread_scope() -> None: api_key=os.getenv("OPENAI_API_KEY"), ) - provider = RedisProvider( + provider = RedisContextProvider( redis_url="redis://localhost:6379", index_name="redis_threads_global", - # overwrite_redis_index=True, - # drop_redis_index=True, application_id="threads_demo_app", agent_id="threads_demo_agent", user_id="threads_demo_user", thread_id=global_thread_id, - scope_to_per_operation_thread_id=False, # Share memories across all threads + scope_to_per_operation_thread_id=False, # Share memories across all sessions ) agent = client.as_agent( @@ -70,7 +68,7 @@ async def example_global_thread_scope() -> None: "Before answering, always check for stored context containing information" ), tools=[], - context_provider=provider, + context_providers=[provider], ) # Store a preference in the global scope @@ -79,11 +77,11 @@ async def example_global_thread_scope() -> None: result = await agent.run(query) print(f"Agent: {result}\n") - 
# Create a new thread - memories should still be accessible due to global scope - new_thread = agent.get_new_thread() + # Create a new session - memories should still be accessible due to global scope + new_session = agent.create_session() query = "What technical responses do I prefer?" - print(f"User (new thread): {query}") - result = await agent.run(query, thread=new_thread) + print(f"User (new session): {query}") + result = await agent.run(query, session=new_session) print(f"Agent: {result}\n") # Clean up the Redis index @@ -91,10 +89,10 @@ async def example_global_thread_scope() -> None: async def example_per_operation_thread_scope() -> None: - """Example 2: Per-operation thread scope (memories isolated per thread). + """Example 2: Per-operation thread scope (memories isolated per session). - Note: When scope_to_per_operation_thread_id=True, the provider is bound to a single thread - throughout its lifetime. Use the same thread object for all operations with that provider. + Note: When scope_to_per_operation_thread_id=True, the provider is bound to a single session + throughout its lifetime. Use the same session object for all operations with that provider. """ print("2. 
Per-Operation Thread Scope Example:") print("-" * 40) @@ -110,7 +108,7 @@ async def example_per_operation_thread_scope() -> None: cache=EmbeddingsCache(name="openai_embeddings_cache", redis_url="redis://localhost:6379"), ) - provider = RedisProvider( + provider = RedisContextProvider( redis_url="redis://localhost:6379", index_name="redis_threads_dynamic", # overwrite_redis_index=True, @@ -118,7 +116,7 @@ async def example_per_operation_thread_scope() -> None: application_id="threads_demo_app", agent_id="threads_demo_agent", user_id="threads_demo_user", - scope_to_per_operation_thread_id=True, # Isolate memories per thread + scope_to_per_operation_thread_id=True, # Isolate memories per session redis_vectorizer=vectorizer, vector_field_name="vector", vector_algorithm="hnsw", @@ -128,34 +126,34 @@ async def example_per_operation_thread_scope() -> None: agent = client.as_agent( name="ScopedMemoryAssistant", instructions="You are an assistant with thread-scoped memory.", - context_provider=provider, + context_providers=[provider], ) - # Create a specific thread for this scoped provider - dedicated_thread = agent.get_new_thread() + # Create a specific session for this scoped provider + dedicated_session = agent.create_session() - # Store some information in the dedicated thread + # Store some information in the dedicated session query = "Remember that for this conversation, I'm working on a Python project about data analysis." - print(f"User (dedicated thread): {query}") - result = await agent.run(query, thread=dedicated_thread) + print(f"User (dedicated session): {query}") + result = await agent.run(query, session=dedicated_session) print(f"Agent: {result}\n") - # Test memory retrieval in the same dedicated thread + # Test memory retrieval in the same dedicated session query = "What project am I working on?" 
- print(f"User (same dedicated thread): {query}") - result = await agent.run(query, thread=dedicated_thread) + print(f"User (same dedicated session): {query}") + result = await agent.run(query, session=dedicated_session) print(f"Agent: {result}\n") - # Store more information in the same thread + # Store more information in the same session query = "Also remember that I prefer using pandas and matplotlib for this project." - print(f"User (same dedicated thread): {query}") - result = await agent.run(query, thread=dedicated_thread) + print(f"User (same dedicated session): {query}") + result = await agent.run(query, session=dedicated_session) print(f"Agent: {result}\n") # Test comprehensive memory retrieval query = "What do you know about my current project and preferences?" - print(f"User (same dedicated thread): {query}") - result = await agent.run(query, thread=dedicated_thread) + print(f"User (same dedicated session): {query}") + result = await agent.run(query, session=dedicated_session) print(f"Agent: {result}\n") # Clean up the Redis index @@ -178,7 +176,7 @@ async def example_multiple_agents() -> None: cache=EmbeddingsCache(name="openai_embeddings_cache", redis_url="redis://localhost:6379"), ) - personal_provider = RedisProvider( + personal_provider = RedisContextProvider( redis_url="redis://localhost:6379", index_name="redis_threads_agents", application_id="threads_demo_app", @@ -193,10 +191,10 @@ async def example_multiple_agents() -> None: personal_agent = client.as_agent( name="PersonalAssistant", instructions="You are a personal assistant that helps with personal tasks.", - context_provider=personal_provider, + context_providers=[personal_provider], ) - work_provider = RedisProvider( + work_provider = RedisContextProvider( redis_url="redis://localhost:6379", index_name="redis_threads_agents", application_id="threads_demo_app", @@ -211,7 +209,7 @@ async def example_multiple_agents() -> None: work_agent = client.as_agent( name="WorkAssistant", 
instructions="You are a work assistant that helps with professional tasks.", - context_provider=work_provider, + context_providers=[work_provider], ) # Store personal information diff --git a/python/samples/02-agents/context_providers/simple_context_provider.py b/python/samples/02-agents/context_providers/simple_context_provider.py index e151651199..940e6a057e 100644 --- a/python/samples/02-agents/context_providers/simple_context_provider.py +++ b/python/samples/02-agents/context_providers/simple_context_provider.py @@ -1,10 +1,9 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio -from collections.abc import MutableSequence, Sequence from typing import Any -from agent_framework import Agent, Context, ContextProvider, Message, SupportsChatGetResponse +from agent_framework import Agent, AgentSession, BaseContextProvider, Message, SessionContext, SupportsChatGetResponse from agent_framework.azure import AzureAIClient from azure.identity.aio import AzureCliCredential from pydantic import BaseModel @@ -15,13 +14,13 @@ class UserInfo(BaseModel): age: int | None = None -class UserInfoMemory(ContextProvider): +class UserInfoMemory(BaseContextProvider): def __init__(self, client: SupportsChatGetResponse, user_info: UserInfo | None = None, **kwargs: Any): """Create the memory. If you pass in kwargs, they will be attempted to be used to create a UserInfo object. 
""" - + super().__init__("user-info-memory") self._chat_client = client if user_info: self.user_info = user_info @@ -30,14 +29,16 @@ def __init__(self, client: SupportsChatGetResponse, user_info: UserInfo | None = else: self.user_info = UserInfo() - async def invoked( + async def after_run( self, - request_messages: Message | Sequence[Message], - response_messages: Message | Sequence[Message] | None = None, - invoke_exception: Exception | None = None, - **kwargs: Any, + *, + agent: Any, + session: AgentSession | None, + context: SessionContext, + state: dict[str, Any], ) -> None: """Extract user information from messages after each agent call.""" + request_messages = context.get_messages() # Check if we need to extract user info from user messages user_messages = [msg for msg in request_messages if hasattr(msg, "role") and msg.role == "user"] # type: ignore @@ -64,7 +65,14 @@ async def invoked( except Exception: pass # Failed to extract, continue without updating - async def invoking(self, messages: Message | MutableSequence[Message], **kwargs: Any) -> Context: + async def before_run( + self, + *, + agent: Any, + session: AgentSession | None, + context: SessionContext, + state: dict[str, Any], + ) -> None: """Provide user information context before each agent call.""" instructions: list[str] = [] @@ -82,11 +90,11 @@ async def invoking(self, messages: Message | MutableSequence[Message], **kwargs: else: instructions.append(f"The user's age is {self.user_info.age}.") - # Return context with additional instructions - return Context(instructions=" ".join(instructions)) + # Add context with additional instructions + context.extend_instructions(self.source_id, " ".join(instructions)) def serialize(self) -> str: - """Serialize the user info for thread persistence.""" + """Serialize the user info for session persistence.""" return self.user_info.model_dump_json() @@ -101,21 +109,20 @@ async def main(): async with Agent( client=client, instructions="You are a friendly 
assistant. Always address the user by their name.", - context_provider=memory_provider, + context_providers=[memory_provider], ) as agent: - # Create a new thread for the conversation - thread = agent.get_new_thread() + # Create a new session for the conversation + session = agent.create_session() - print(await agent.run("Hello, what is the square root of 9?", thread=thread)) - print(await agent.run("My name is Ruaidhrí", thread=thread)) - print(await agent.run("I am 20 years old", thread=thread)) + print(await agent.run("Hello, what is the square root of 9?", session=session)) + print(await agent.run("My name is Ruaidhrí", session=session)) + print(await agent.run("I am 20 years old", session=session)) - # Access the memory component via the thread's get_service method and inspect the memories - user_info_memory = thread.context_provider.providers[0] # type: ignore - if user_info_memory: + # Access the memory component and inspect the memories + if memory_provider: print() - print(f"MEMORY - User Name: {user_info_memory.user_info.name}") # type: ignore - print(f"MEMORY - User Age: {user_info_memory.user_info.age}") # type: ignore + print(f"MEMORY - User Name: {memory_provider.user_info.name}") + print(f"MEMORY - User Age: {memory_provider.user_info.age}") if __name__ == "__main__": diff --git a/python/samples/02-agents/conversations/custom_chat_message_store_thread.py b/python/samples/02-agents/conversations/custom_chat_message_store_thread.py index b5ab03bbcb..9470e7bbb2 100644 --- a/python/samples/02-agents/conversations/custom_chat_message_store_thread.py +++ b/python/samples/02-agents/conversations/custom_chat_message_store_thread.py @@ -1,92 +1,84 @@ # Copyright (c) Microsoft. All rights reserved. 
import asyncio -from collections.abc import Collection +from collections.abc import Sequence from typing import Any -from agent_framework import ChatMessageStoreProtocol, Message -from agent_framework._threads import ChatMessageStoreState +from agent_framework import AgentSession, BaseHistoryProvider, Message, SessionContext from agent_framework.openai import OpenAIChatClient """ -Custom Chat Message Store Thread Example +Custom History Provider Example -This sample demonstrates how to implement and use a custom chat message store -for thread management, allowing you to persist conversation history in your +This sample demonstrates how to implement and use a custom history provider +for session management, allowing you to persist conversation history in your preferred storage solution (database, file system, etc.). """ -class CustomChatMessageStore(ChatMessageStoreProtocol): - """Implementation of custom chat message store. +class CustomHistoryProvider(BaseHistoryProvider): + """Implementation of custom history provider. 
In real applications, this can be an implementation of relational database or vector store.""" - def __init__(self, messages: Collection[Message] | None = None) -> None: - self._messages: list[Message] = [] - if messages: - self._messages.extend(messages) - - async def add_messages(self, messages: Collection[Message]) -> None: - self._messages.extend(messages) - - async def list_messages(self) -> list[Message]: - return self._messages - - @classmethod - async def deserialize(cls, serialized_store_state: Any, **kwargs: Any) -> "CustomChatMessageStore": - """Create a new instance from serialized state.""" - store = cls() - await store.update_from_state(serialized_store_state, **kwargs) - return store - - async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) -> None: - """Update this instance from serialized state.""" - if serialized_store_state: - state = ChatMessageStoreState.from_dict(serialized_store_state, **kwargs) - if state.messages: - self._messages.extend(state.messages) - - async def serialize(self, **kwargs: Any) -> Any: - """Serialize this store's state.""" - state = ChatMessageStoreState(messages=self._messages) - return state.to_dict(**kwargs) + def __init__(self) -> None: + super().__init__("custom-history") + self._storage: dict[str, list[Message]] = {} + + async def get_messages( + self, session_id: str | None, *, state: dict[str, Any] | None = None, **kwargs: Any + ) -> list[Message]: + key = session_id or "default" + return list(self._storage.get(key, [])) + + async def save_messages( + self, + session_id: str | None, + messages: Sequence[Message], + *, + state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + key = session_id or "default" + if key not in self._storage: + self._storage[key] = [] + self._storage[key].extend(messages) async def main() -> None: - """Demonstrates how to use 3rd party or custom chat message store for threads.""" - print("=== Thread with 3rd party or custom chat message store ===") + 
"""Demonstrates how to use 3rd party or custom history provider for sessions.""" + print("=== Session with 3rd party or custom history provider ===") # OpenAI Chat Client is used as an example here, # other chat clients can be used as well. agent = OpenAIChatClient().as_agent( name="CustomBot", instructions="You are a helpful assistant that remembers our conversation.", - # Use custom chat message store. - # If not provided, the default in-memory store will be used. - chat_message_store_factory=CustomChatMessageStore, + # Use custom history provider. + # If not provided, the default in-memory provider will be used. + context_providers=[CustomHistoryProvider()], ) - # Start a new thread for the agent conversation. - thread = agent.get_new_thread() + # Start a new session for the agent conversation. + session = agent.create_session() # Respond to user input. query = "Hello! My name is Alice and I love pizza." print(f"User: {query}") - print(f"Agent: {await agent.run(query, thread=thread)}\n") + print(f"Agent: {await agent.run(query, session=session)}\n") - # Serialize the thread state, so it can be stored for later use. - serialized_thread = await thread.serialize() + # Serialize the session state, so it can be stored for later use. + serialized_session = session.to_dict() - # The thread can now be saved to a database, file, or any other storage mechanism and loaded again later. - print(f"Serialized thread: {serialized_thread}\n") + # The session can now be saved to a database, file, or any other storage mechanism and loaded again later. + print(f"Serialized session: {serialized_session}\n") - # Deserialize the thread state after loading from storage. - resumed_thread = await agent.deserialize_thread(serialized_thread) + # Deserialize the session state after loading from storage. + resumed_session = AgentSession.from_dict(serialized_session) # Respond to user input. query = "What do you remember about me?" 
print(f"User: {query}") - print(f"Agent: {await agent.run(query, thread=resumed_thread)}\n") + print(f"Agent: {await agent.run(query, session=resumed_session)}\n") if __name__ == "__main__": diff --git a/python/samples/02-agents/conversations/redis_chat_message_store_thread.py b/python/samples/02-agents/conversations/redis_chat_message_store_thread.py index 217355eb72..5f1a0371f8 100644 --- a/python/samples/02-agents/conversations/redis_chat_message_store_thread.py +++ b/python/samples/02-agents/conversations/redis_chat_message_store_thread.py @@ -4,60 +4,51 @@ import os from uuid import uuid4 -from agent_framework import AgentThread +from agent_framework import AgentSession from agent_framework.openai import OpenAIChatClient -from agent_framework.redis import RedisChatMessageStore +from agent_framework.redis import RedisHistoryProvider """ -Redis Chat Message Store Thread Example +Redis History Provider Session Example -This sample demonstrates how to use Redis as a chat message store for thread +This sample demonstrates how to use Redis as a history provider for session management, enabling persistent conversation history storage across sessions with Redis as the backend data store. 
""" async def example_manual_memory_store() -> None: - """Basic example of using Redis chat message store.""" - print("=== Basic Redis Chat Message Store Example ===") + """Basic example of using Redis history provider.""" + print("=== Basic Redis History Provider Example ===") - # Create Redis store with auto-generated thread ID - redis_store = RedisChatMessageStore( + # Create Redis history provider + redis_provider = RedisHistoryProvider( redis_url="redis://localhost:6379", - # thread_id will be auto-generated if not provided ) - print(f"Created store with thread ID: {redis_store.thread_id}") - - # Create thread with Redis store - thread = AgentThread(message_store=redis_store) - - # Create agent + # Create agent with Redis history provider agent = OpenAIChatClient().as_agent( name="RedisBot", instructions="You are a helpful assistant that remembers our conversation using Redis.", + context_providers=[redis_provider], ) + # Create session + session = agent.create_session() + # Have a conversation print("\n--- Starting conversation ---") query1 = "Hello! My name is Alice and I love pizza." print(f"User: {query1}") - response1 = await agent.run(query1, thread=thread) + response1 = await agent.run(query1, session=session) print(f"Agent: {response1.text}") query2 = "What do you remember about me?" 
print(f"User: {query2}") - response2 = await agent.run(query2, thread=thread) + response2 = await agent.run(query2, session=session) print(f"Agent: {response2.text}") - # Show messages are stored in Redis - messages = await redis_store.list_messages() - print(f"\nTotal messages in Redis: {len(messages)}") - - # Cleanup - await redis_store.clear() - await redis_store.aclose() - print("Cleaned up Redis data\n") + print("Done\n") async def example_user_session_management() -> None: @@ -67,27 +58,23 @@ async def example_user_session_management() -> None: user_id = "alice_123" session_id = f"session_{uuid4()}" - # Create Redis store for specific user session - def create_user_session_store(): - return RedisChatMessageStore( - redis_url="redis://localhost:6379", - thread_id=f"user_{user_id}_{session_id}", - max_messages=10, # Keep only last 10 messages - ) + # Create Redis history provider for specific user session + redis_provider = RedisHistoryProvider( + redis_url="redis://localhost:6379", + max_messages=10, # Keep only last 10 messages + ) - # Create agent with factory pattern + # Create agent with history provider agent = OpenAIChatClient().as_agent( name="SessionBot", instructions="You are a helpful assistant. 
Keep track of user preferences.", - chat_message_store_factory=create_user_session_store, + context_providers=[redis_provider], ) # Start conversation - thread = agent.get_new_thread() + session = agent.create_session() print(f"Started session for user {user_id}") - if hasattr(thread.message_store, "thread_id"): - print(f"Thread ID: {thread.message_store.thread_id}") # type: ignore[union-attr] # Simulate conversation queries = [ @@ -100,152 +87,120 @@ def create_user_session_store(): for i, query in enumerate(queries, 1): print(f"\n--- Message {i} ---") print(f"User: {query}") - response = await agent.run(query, thread=thread) + response = await agent.run(query, session=session) print(f"Agent: {response.text}") - # Show persistent storage - if thread.message_store: - messages = await thread.message_store.list_messages() # type: ignore[union-attr] - print(f"\nMessages stored for user {user_id}: {len(messages)}") - - # Cleanup - if thread.message_store: - await thread.message_store.clear() # type: ignore[union-attr] - await thread.message_store.aclose() # type: ignore[union-attr] - print("Cleaned up session data\n") + print("Done\n") async def example_conversation_persistence() -> None: """Example of conversation persistence across application restarts.""" print("=== Conversation Persistence Example ===") - conversation_id = "persistent_chat_001" - # Phase 1: Start conversation print("--- Phase 1: Starting conversation ---") - store1 = RedisChatMessageStore( + redis_provider = RedisHistoryProvider( redis_url="redis://localhost:6379", - thread_id=conversation_id, ) - thread1 = AgentThread(message_store=store1) agent = OpenAIChatClient().as_agent( name="PersistentBot", instructions="You are a helpful assistant. Remember our conversation history.", + context_providers=[redis_provider], ) + session = agent.create_session() + # Start conversation query1 = "Hello! I'm working on a Python project about machine learning." 
print(f"User: {query1}") - response1 = await agent.run(query1, thread=thread1) + response1 = await agent.run(query1, session=session) print(f"Agent: {response1.text}") query2 = "I'm specifically interested in neural networks." print(f"User: {query2}") - response2 = await agent.run(query2, thread=thread1) + response2 = await agent.run(query2, session=session) print(f"Agent: {response2.text}") - print(f"Stored {len(await store1.list_messages())} messages in Redis") - await store1.aclose() + # Serialize session state + serialized = session.to_dict() # Phase 2: Resume conversation (simulating app restart) print("\n--- Phase 2: Resuming conversation (after 'restart') ---") - store2 = RedisChatMessageStore( - redis_url="redis://localhost:6379", - thread_id=conversation_id, # Same thread ID - ) - - thread2 = AgentThread(message_store=store2) + restored_session = AgentSession.from_dict(serialized) # Continue conversation - agent should remember context query3 = "What was I working on before?" print(f"User: {query3}") - response3 = await agent.run(query3, thread=thread2) + response3 = await agent.run(query3, session=restored_session) print(f"Agent: {response3.text}") query4 = "Can you suggest some Python libraries for neural networks?" 
print(f"User: {query4}") - response4 = await agent.run(query4, thread=thread2) + response4 = await agent.run(query4, session=restored_session) print(f"Agent: {response4.text}") - print(f"Total messages after resuming: {len(await store2.list_messages())}") + print("Done\n") - # Cleanup - await store2.clear() - await store2.aclose() - print("Cleaned up persistent data\n") +async def example_session_serialization() -> None: + """Example of session state serialization and deserialization.""" + print("=== Session Serialization Example ===") -async def example_thread_serialization() -> None: - """Example of thread state serialization and deserialization.""" - print("=== Thread Serialization Example ===") - - # Create initial thread with Redis store - original_store = RedisChatMessageStore( + redis_provider = RedisHistoryProvider( redis_url="redis://localhost:6379", - thread_id="serialization_test", - max_messages=50, ) - original_thread = AgentThread(message_store=original_store) - agent = OpenAIChatClient().as_agent( name="SerializationBot", instructions="You are a helpful assistant.", + context_providers=[redis_provider], ) + session = agent.create_session() + # Have initial conversation print("--- Initial conversation ---") query1 = "Hello! I'm testing serialization." 
print(f"User: {query1}") - response1 = await agent.run(query1, thread=original_thread) + response1 = await agent.run(query1, session=session) print(f"Agent: {response1.text}") - # Serialize thread state - serialized_thread = await original_thread.serialize() - print(f"\nSerialized thread state: {serialized_thread}") + # Serialize session state + serialized = session.to_dict() + print(f"\nSerialized session state: {serialized}") - # Close original connection - await original_store.aclose() + # Deserialize session state (simulating loading from database/file) + print("\n--- Deserializing session state ---") + restored_session = AgentSession.from_dict(serialized) - # Deserialize thread state (simulating loading from database/file) - print("\n--- Deserializing thread state ---") - - # Create a new thread with the same Redis store type - # This ensures the correct store type is used for deserialization - restored_store = RedisChatMessageStore(redis_url="redis://localhost:6379") - restored_thread = await AgentThread.deserialize(serialized_thread, message_store=restored_store) - - # Continue conversation with restored thread + # Continue conversation with restored session query2 = "Do you remember what I said about testing?" 
print(f"User: {query2}") - response2 = await agent.run(query2, thread=restored_thread) + response2 = await agent.run(query2, session=restored_session) print(f"Agent: {response2.text}") - # Cleanup - if restored_thread.message_store: - await restored_thread.message_store.clear() # type: ignore[union-attr] - await restored_thread.message_store.aclose() # type: ignore[union-attr] - print("Cleaned up serialization test data\n") + print("Done\n") async def example_message_limits() -> None: """Example of automatic message trimming with limits.""" print("=== Message Limits Example ===") - # Create store with small message limit - store = RedisChatMessageStore( + # Create provider with small message limit + redis_provider = RedisHistoryProvider( redis_url="redis://localhost:6379", - thread_id="limits_test", max_messages=3, # Keep only 3 most recent messages ) - thread = AgentThread(message_store=store) agent = OpenAIChatClient().as_agent( name="LimitBot", instructions="You are a helpful assistant with limited memory.", + context_providers=[redis_provider], ) + session = agent.create_session() + # Send multiple messages to test trimming messages = [ "Message 1: Hello!", @@ -258,27 +213,15 @@ async def example_message_limits() -> None: for i, query in enumerate(messages, 1): print(f"\n--- Sending message {i} ---") print(f"User: {query}") - response = await agent.run(query, thread=thread) + response = await agent.run(query, session=session) print(f"Agent: {response.text}") - stored_messages = await store.list_messages() - print(f"Messages in store: {len(stored_messages)}") - if len(stored_messages) > 0: - print(f"Oldest message: {stored_messages[0].text[:30]}...") - - # Final check - final_messages = await store.list_messages() - print(f"\nFinal message count: {len(final_messages)} (should be <= 6: 3 messages × 2 per exchange)") - - # Cleanup - await store.clear() - await store.aclose() - print("Cleaned up limits test data\n") + print("Done\n") async def main() -> None: - 
"""Run all Redis chat message store examples.""" - print("Redis Chat Message Store Examples") + """Run all Redis history provider examples.""" + print("Redis History Provider Examples") print("=" * 50) print("Prerequisites:") print("- Redis server running on localhost:6379") @@ -290,25 +233,12 @@ async def main() -> None: print("ERROR: OPENAI_API_KEY environment variable not set") return - try: - # Test Redis connection - test_store = RedisChatMessageStore(redis_url="redis://localhost:6379") - connection_ok = await test_store.ping() - await test_store.aclose() - if not connection_ok: - raise Exception("Redis ping failed") - print("✓ Redis connection successful\n") - except Exception as e: - print(f"ERROR: Cannot connect to Redis: {e}") - print("Please ensure Redis is running on localhost:6379") - return - try: # Run all examples await example_manual_memory_store() await example_user_session_management() await example_conversation_persistence() - await example_thread_serialization() + await example_session_serialization() await example_message_limits() print("All examples completed successfully!") diff --git a/python/samples/02-agents/conversations/suspend_resume_thread.py b/python/samples/02-agents/conversations/suspend_resume_thread.py index 5799505d02..dcbb00d06a 100644 --- a/python/samples/02-agents/conversations/suspend_resume_thread.py +++ b/python/samples/02-agents/conversations/suspend_resume_thread.py @@ -2,56 +2,57 @@ import asyncio +from agent_framework import AgentSession from agent_framework.azure import AzureAIAgentClient from agent_framework.openai import OpenAIChatClient from azure.identity.aio import AzureCliCredential """ -Thread Suspend and Resume Example +Session Suspend and Resume Example -This sample demonstrates how to suspend and resume conversation threads, comparing -service-managed threads (Azure AI) with in-memory threads (OpenAI) for persistent +This sample demonstrates how to suspend and resume conversation sessions, comparing 
+service-managed sessions (Azure AI) with in-memory sessions (OpenAI) for persistent conversation state across sessions. """ -async def suspend_resume_service_managed_thread() -> None: - """Demonstrates how to suspend and resume a service-managed thread.""" - print("=== Suspend-Resume Service-Managed Thread ===") +async def suspend_resume_service_managed_session() -> None: + """Demonstrates how to suspend and resume a service-managed session.""" + print("=== Suspend-Resume Service-Managed Session ===") - # AzureAIAgentClient supports service-managed threads. + # AzureAIAgentClient supports service-managed sessions. async with ( AzureCliCredential() as credential, AzureAIAgentClient(credential=credential).as_agent( name="MemoryBot", instructions="You are a helpful assistant that remembers our conversation." ) as agent, ): - # Start a new thread for the agent conversation. - thread = agent.get_new_thread() + # Start a new session for the agent conversation. + session = agent.create_session() # Respond to user input. query = "Hello! My name is Alice and I love pizza." print(f"User: {query}") - print(f"Agent: {await agent.run(query, thread=thread)}\n") + print(f"Agent: {await agent.run(query, session=session)}\n") - # Serialize the thread state, so it can be stored for later use. - serialized_thread = await thread.serialize() + # Serialize the session state, so it can be stored for later use. + serialized_session = session.to_dict() - # The thread can now be saved to a database, file, or any other storage mechanism and loaded again later. - print(f"Serialized thread: {serialized_thread}\n") + # The session can now be saved to a database, file, or any other storage mechanism and loaded again later. + print(f"Serialized session: {serialized_session}\n") - # Deserialize the thread state after loading from storage. - resumed_thread = await agent.deserialize_thread(serialized_thread) + # Deserialize the session state after loading from storage. 
+ resumed_session = AgentSession.from_dict(serialized_session) # Respond to user input. query = "What do you remember about me?" print(f"User: {query}") - print(f"Agent: {await agent.run(query, thread=resumed_thread)}\n") + print(f"Agent: {await agent.run(query, session=resumed_session)}\n") -async def suspend_resume_in_memory_thread() -> None: - """Demonstrates how to suspend and resume an in-memory thread.""" - print("=== Suspend-Resume In-Memory Thread ===") +async def suspend_resume_in_memory_session() -> None: + """Demonstrates how to suspend and resume an in-memory session.""" + print("=== Suspend-Resume In-Memory Session ===") # OpenAI Chat Client is used as an example here, # other chat clients can be used as well. @@ -59,33 +60,33 @@ async def suspend_resume_in_memory_thread() -> None: name="MemoryBot", instructions="You are a helpful assistant that remembers our conversation." ) - # Start a new thread for the agent conversation. - thread = agent.get_new_thread() + # Start a new session for the agent conversation. + session = agent.create_session() # Respond to user input. query = "Hello! My name is Alice and I love pizza." print(f"User: {query}") - print(f"Agent: {await agent.run(query, thread=thread)}\n") + print(f"Agent: {await agent.run(query, session=session)}\n") - # Serialize the thread state, so it can be stored for later use. - serialized_thread = await thread.serialize() + # Serialize the session state, so it can be stored for later use. + serialized_session = session.to_dict() - # The thread can now be saved to a database, file, or any other storage mechanism and loaded again later. - print(f"Serialized thread: {serialized_thread}\n") + # The session can now be saved to a database, file, or any other storage mechanism and loaded again later. + print(f"Serialized session: {serialized_session}\n") - # Deserialize the thread state after loading from storage. 
- resumed_thread = await agent.deserialize_thread(serialized_thread) + # Deserialize the session state after loading from storage. + resumed_session = AgentSession.from_dict(serialized_session) # Respond to user input. query = "What do you remember about me?" print(f"User: {query}") - print(f"Agent: {await agent.run(query, thread=resumed_thread)}\n") + print(f"Agent: {await agent.run(query, session=resumed_session)}\n") async def main() -> None: - print("=== Suspend-Resume Thread Examples ===") - await suspend_resume_service_managed_thread() - await suspend_resume_in_memory_thread() + print("=== Suspend-Resume Session Examples ===") + await suspend_resume_service_managed_session() + await suspend_resume_in_memory_session() if __name__ == "__main__": diff --git a/python/samples/02-agents/middleware/thread_behavior_middleware.py b/python/samples/02-agents/middleware/thread_behavior_middleware.py index b09e50a5c0..d20393b456 100644 --- a/python/samples/02-agents/middleware/thread_behavior_middleware.py +++ b/python/samples/02-agents/middleware/thread_behavior_middleware.py @@ -6,7 +6,6 @@ from agent_framework import ( AgentContext, - ChatMessageStore, tool, ) from agent_framework.azure import AzureOpenAIChatClient @@ -16,19 +15,19 @@ """ Thread Behavior MiddlewareTypes Example -This sample demonstrates how middleware can access and track thread state across multiple agent runs. +This sample demonstrates how middleware can access and track session state across multiple agent runs. 
The example shows: -- How AgentContext.thread property behaves across multiple runs -- How middleware can access conversation history through the thread -- The timing of when thread messages are populated (before vs after call_next() call) -- How to track thread state changes across runs +- How AgentContext.session property behaves across multiple runs +- How middleware can access conversation history through the session +- The timing of when session messages are populated (before vs after call_next() call) +- How to track session state changes across runs Key behaviors demonstrated: -1. First run: context.messages is populated, context.thread is initially empty (before call_next()) -2. After call_next(): thread contains input message + response from agent -3. Second run: context.messages contains only current input, thread contains previous history -4. After call_next(): thread contains full conversation history (all previous + current messages) +1. First run: context.messages is populated, context.session is initially empty (before call_next()) +2. After call_next(): session contains input message + response from agent +3. Second run: context.messages contains only current input, session contains previous history +4. 
After call_next(): session contains full conversation history (all previous + current messages) """ @@ -48,28 +47,30 @@ async def thread_tracking_middleware( context: AgentContext, call_next: Callable[[], Awaitable[None]], ) -> None: - """MiddlewareTypes that tracks and logs thread behavior across runs.""" - thread_messages = [] - if context.thread and context.thread.message_store: - thread_messages = await context.thread.message_store.list_messages() + """MiddlewareTypes that tracks and logs session behavior across runs.""" + session_message_count = 0 + if context.session: + memory_state = context.session.state.get("memory", {}) + session_message_count = len(memory_state.get("messages", [])) print(f"[MiddlewareTypes pre-execution] Current input messages: {len(context.messages)}") - print(f"[MiddlewareTypes pre-execution] Thread history messages: {len(thread_messages)}") + print(f"[MiddlewareTypes pre-execution] Session history messages: {session_message_count}") # Call call_next to execute the agent await call_next() - # Check thread state after agent execution - updated_thread_messages = [] - if context.thread and context.thread.message_store: - updated_thread_messages = await context.thread.message_store.list_messages() + # Check session state after agent execution + updated_session_message_count = 0 + if context.session: + memory_state = context.session.state.get("memory", {}) + updated_session_message_count = len(memory_state.get("messages", [])) - print(f"[MiddlewareTypes post-execution] Updated thread messages: {len(updated_thread_messages)}") + print(f"[MiddlewareTypes post-execution] Updated session messages: {updated_session_message_count}") async def main() -> None: - """Example demonstrating thread behavior in middleware across multiple runs.""" - print("=== Thread Behavior MiddlewareTypes Example ===") + """Example demonstrating session behavior in middleware across multiple runs.""" + print("=== Session Behavior MiddlewareTypes Example ===") # For 
authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. @@ -78,23 +79,21 @@ async def main() -> None: instructions="You are a helpful weather assistant.", tools=get_weather, middleware=[thread_tracking_middleware], - # Configure agent with message store factory to persist conversation history - chat_message_store_factory=ChatMessageStore, ) - # Create a thread that will persist messages between runs - thread = agent.get_new_thread() + # Create a session that will persist messages between runs + session = agent.create_session() print("\nFirst Run:") query1 = "What's the weather like in Tokyo?" print(f"User: {query1}") - result1 = await agent.run(query1, thread=thread) + result1 = await agent.run(query1, session=session) print(f"Agent: {result1.text}") print("\nSecond Run:") query2 = "How about in London?" print(f"User: {query2}") - result2 = await agent.run(query2, thread=thread) + result2 = await agent.run(query2, session=session) print(f"Agent: {result2.text}") diff --git a/python/samples/02-agents/observability/agent_observability.py b/python/samples/02-agents/observability/agent_observability.py index 46bc92b74a..0cc7700625 100644 --- a/python/samples/02-agents/observability/agent_observability.py +++ b/python/samples/02-agents/observability/agent_observability.py @@ -46,13 +46,13 @@ async def main(): instructions="You are a weather assistant.", id="weather-agent", ) - thread = agent.get_new_thread() + session = agent.create_session() for question in questions: print(f"\nUser: {question}") print(f"{agent.name}: ", end="") async for update in agent.run( question, - thread=thread, + session=session, stream=True, ): if update.text: diff --git a/python/samples/02-agents/observability/agent_with_foundry_tracing.py b/python/samples/02-agents/observability/agent_with_foundry_tracing.py index bd46e81fef..242dbd080a 100644 --- a/python/samples/02-agents/observability/agent_with_foundry_tracing.py +++ 
b/python/samples/02-agents/observability/agent_with_foundry_tracing.py @@ -92,11 +92,11 @@ async def main(): instructions="You are a weather assistant.", id="weather-agent", ) - thread = agent.get_new_thread() + session = agent.create_session() for question in questions: print(f"\nUser: {question}") print(f"{agent.name}: ", end="") - async for update in agent.run(question, thread=thread, stream=True): + async for update in agent.run(question, session=session, stream=True): if update.text: print(update.text, end="") diff --git a/python/samples/02-agents/observability/azure_ai_agent_observability.py b/python/samples/02-agents/observability/azure_ai_agent_observability.py index 90946ad026..9395edcc38 100644 --- a/python/samples/02-agents/observability/azure_ai_agent_observability.py +++ b/python/samples/02-agents/observability/azure_ai_agent_observability.py @@ -63,11 +63,11 @@ async def main(): instructions="You are a weather assistant.", id="edvan-weather-agent", ) - thread = agent.get_new_thread() + session = agent.create_session() for question in questions: print(f"\nUser: {question}") print(f"{agent.name}: ", end="") - async for update in agent.run(question, thread=thread, stream=True): + async for update in agent.run(question, session=session, stream=True): if update.text: print(update.text, end="") diff --git a/python/samples/02-agents/providers/anthropic/anthropic_claude_with_session.py b/python/samples/02-agents/providers/anthropic/anthropic_claude_with_session.py index 2549457800..623be4f299 100644 --- a/python/samples/02-agents/providers/anthropic/anthropic_claude_with_session.py +++ b/python/samples/02-agents/providers/anthropic/anthropic_claude_with_session.py @@ -66,31 +66,31 @@ async def example_with_session_persistence() -> None: ) async with agent: - # Create a thread to maintain conversation context - thread = agent.get_new_thread() + # Create a session to maintain conversation context + session = agent.create_session() # First query query1 = "What's 
the weather like in Tokyo?" print(f"User: {query1}") - result1 = await agent.run(query1, thread=thread) + result1 = await agent.run(query1, session=session) print(f"Agent: {result1.text}") # Second query - using same thread maintains context query2 = "How about London?" print(f"\nUser: {query2}") - result2 = await agent.run(query2, thread=thread) + result2 = await agent.run(query2, session=session) print(f"Agent: {result2.text}") # Third query - agent should remember both previous cities query3 = "Which of the cities I asked about has better weather?" print(f"\nUser: {query3}") - result3 = await agent.run(query3, thread=thread) + result3 = await agent.run(query3, session=session) print(f"Agent: {result3.text}") print("Note: The agent remembers context from previous messages in the same session.\n") async def example_with_existing_session_id() -> None: - """Resume session in new agent instance using service_thread_id.""" + """Resume session in new agent instance using service_session_id.""" print("=== Existing Session ID Example ===") existing_session_id = None @@ -102,15 +102,15 @@ async def example_with_existing_session_id() -> None: ) async with agent1: - thread = agent1.get_new_thread() + session = agent1.create_session() query1 = "What's the weather in Paris?" 
print(f"User: {query1}") - result1 = await agent1.run(query1, thread=thread) + result1 = await agent1.run(query1, session=session) print(f"Agent: {result1.text}") # Capture the session ID for later use - existing_session_id = thread.service_thread_id + existing_session_id = session.service_session_id print(f"Session ID: {existing_session_id}") if existing_session_id: @@ -123,12 +123,12 @@ async def example_with_existing_session_id() -> None: ) async with agent2: - # Create thread with existing session ID - thread = agent2.get_new_thread(service_thread_id=existing_session_id) + # Create session with existing session ID + session = agent2.create_session(service_session_id=existing_session_id) query2 = "What was the last city I asked about?" print(f"User: {query2}") - result2 = await agent2.run(query2, thread=thread) + result2 = await agent2.run(query2, session=session) print(f"Agent: {result2.text}") print("Note: The agent continues the conversation using the session ID.\n") diff --git a/python/samples/02-agents/providers/azure_ai/azure_ai_with_existing_conversation.py b/python/samples/02-agents/providers/azure_ai/azure_ai_with_existing_conversation.py index cb18a7d46e..5c2872d2ed 100644 --- a/python/samples/02-agents/providers/azure_ai/azure_ai_with_existing_conversation.py +++ b/python/samples/02-agents/providers/azure_ai/azure_ai_with_existing_conversation.py @@ -61,9 +61,9 @@ async def example_with_conversation_id() -> None: print(f"Agent: {result.text}\n") -async def example_with_thread() -> None: - """This example shows how to specify existing conversation ID with AgentThread.""" - print("=== Azure AI Agent With Existing Conversation and Thread ===") +async def example_with_session() -> None: + """This example shows how to specify existing conversation ID with AgentSession.""" + print("=== Azure AI Agent With Existing Conversation and Session ===") async with ( AzureCliCredential() as credential, AIProjectClient(endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], 
credential=credential) as project_client, @@ -81,23 +81,23 @@ async def example_with_thread() -> None: conversation_id = conversation.id print(f"Conversation ID: {conversation_id}") - # Create a thread with the existing ID - thread = agent.get_new_thread(service_thread_id=conversation_id) + # Create a session with the existing ID + session = agent.create_session(service_session_id=conversation_id) query = "What's the weather like in Seattle?" print(f"User: {query}") - result = await agent.run(query, thread=thread) + result = await agent.run(query, session=session) print(f"Agent: {result.text}\n") query = "What was my last question?" print(f"User: {query}") - result = await agent.run(query, thread=thread) + result = await agent.run(query, session=session) print(f"Agent: {result.text}\n") async def main() -> None: await example_with_conversation_id() - await example_with_thread() + await example_with_session() if __name__ == "__main__": diff --git a/python/samples/02-agents/providers/azure_ai/azure_ai_with_hosted_mcp.py b/python/samples/02-agents/providers/azure_ai/azure_ai_with_hosted_mcp.py index 75ebd2ea76..02216a3014 100644 --- a/python/samples/02-agents/providers/azure_ai/azure_ai_with_hosted_mcp.py +++ b/python/samples/02-agents/providers/azure_ai/azure_ai_with_hosted_mcp.py @@ -3,7 +3,7 @@ import asyncio from typing import Any -from agent_framework import AgentResponse, AgentThread, Message, SupportsAgentRun +from agent_framework import AgentResponse, AgentSession, Message, SupportsAgentRun from agent_framework.azure import AzureAIClient, AzureAIProjectAgentProvider from azure.identity.aio import AzureCliCredential @@ -14,8 +14,8 @@ """ -async def handle_approvals_without_thread(query: str, agent: "SupportsAgentRun") -> AgentResponse: - """When we don't have a thread, we need to ensure we return with the input, approval request and approval.""" +async def handle_approvals_without_session(query: str, agent: "SupportsAgentRun") -> AgentResponse: + """When we 
don't have a session, we need to ensure we return with the input, approval request and approval.""" result = await agent.run(query, store=False) while len(result.user_input_requests) > 0: @@ -35,10 +35,10 @@ async def handle_approvals_without_thread(query: str, agent: "SupportsAgentRun") return result -async def handle_approvals_with_thread(query: str, agent: "SupportsAgentRun", thread: "AgentThread") -> AgentResponse: - """Here we let the thread deal with the previous responses, and we just rerun with the approval.""" +async def handle_approvals_with_session(query: str, agent: "SupportsAgentRun", session: "AgentSession") -> AgentResponse: + """Here we let the session deal with the previous responses, and we just rerun with the approval.""" - result = await agent.run(query, thread=thread) + result = await agent.run(query, session=session) while len(result.user_input_requests) > 0: new_input: list[Any] = [] for user_input_needed in result.user_input_requests: @@ -53,7 +53,7 @@ async def handle_approvals_with_thread(query: str, agent: "SupportsAgentRun", th contents=[user_input_needed.to_function_approval_response(user_approval.lower() == "y")], ) ) - result = await agent.run(new_input, thread=thread) + result = await agent.run(new_input, session=session) return result @@ -82,13 +82,13 @@ async def run_hosted_mcp_without_approval() -> None: query = "How to create an Azure storage account using az cli?" 
print(f"User: {query}") - result = await handle_approvals_without_thread(query, agent) + result = await handle_approvals_without_session(query, agent) print(f"{agent.name}: {result}\n") -async def run_hosted_mcp_with_approval_and_thread() -> None: - """Example showing MCP Tools with approvals using a thread.""" - print("=== MCP with approvals and with thread ===") +async def run_hosted_mcp_with_approval_and_session() -> None: + """Example showing MCP Tools with approvals using a session.""" + print("=== MCP with approvals and with session ===") # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. @@ -111,10 +111,10 @@ async def run_hosted_mcp_with_approval_and_thread() -> None: tools=[mcp_tool], ) - thread = agent.get_new_thread() + session = agent.create_session() query = "Please summarize the Azure REST API specifications Readme" print(f"User: {query}") - result = await handle_approvals_with_thread(query, agent, thread) + result = await handle_approvals_with_session(query, agent, session) print(f"{agent.name}: {result}\n") @@ -122,7 +122,7 @@ async def main() -> None: print("=== Azure AI Agent with Hosted MCP Tools Example ===\n") await run_hosted_mcp_without_approval() - await run_hosted_mcp_with_approval_and_thread() + await run_hosted_mcp_with_approval_and_session() if __name__ == "__main__": diff --git a/python/samples/02-agents/providers/azure_ai/azure_ai_with_thread.py b/python/samples/02-agents/providers/azure_ai/azure_ai_with_thread.py index 2abc053b71..f9a7fbf4fa 100644 --- a/python/samples/02-agents/providers/azure_ai/azure_ai_with_thread.py +++ b/python/samples/02-agents/providers/azure_ai/azure_ai_with_thread.py @@ -10,10 +10,10 @@ from pydantic import Field """ -Azure AI Agent with Thread Management Example +Azure AI Agent with Session Management Example -This sample demonstrates thread management with Azure AI Agent, showing -persistent conversation capabilities using 
service-managed threads as well as storing messages in-memory. +This sample demonstrates session management with Azure AI Agent, showing +persistent conversation capabilities using service-managed sessions as well as storing messages in-memory. """ @@ -30,9 +30,9 @@ def get_weather( return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." -async def example_with_automatic_thread_creation() -> None: - """Example showing automatic thread creation.""" - print("=== Automatic Thread Creation Example ===") +async def example_with_automatic_session_creation() -> None: + """Example showing automatic session creation.""" + print("=== Automatic Session Creation Example ===") async with ( AzureCliCredential() as credential, @@ -44,26 +44,26 @@ async def example_with_automatic_thread_creation() -> None: tools=get_weather, ) - # First conversation - no thread provided, will be created automatically + # First conversation - no session provided, will be created automatically query1 = "What's the weather like in Seattle?" print(f"User: {query1}") result1 = await agent.run(query1) print(f"Agent: {result1.text}") - # Second conversation - still no thread provided, will create another new thread + # Second conversation - still no session provided, will create another new session query2 = "What was the last city I asked about?" print(f"\nUser: {query2}") result2 = await agent.run(query2) print(f"Agent: {result2.text}") - print("Note: Each call creates a separate thread, so the agent doesn't remember previous context.\n") + print("Note: Each call creates a separate session, so the agent doesn't remember previous context.\n") -async def example_with_thread_persistence_in_memory() -> None: +async def example_with_session_persistence_in_memory() -> None: """ - Example showing thread persistence across multiple conversations. + Example showing session persistence across multiple conversations. In this example, messages are stored in-memory. 
""" - print("=== Thread Persistence Example (In-Memory) ===") + print("=== Session Persistence Example (In-Memory) ===") async with ( AzureCliCredential() as credential, @@ -75,38 +75,38 @@ async def example_with_thread_persistence_in_memory() -> None: tools=get_weather, ) - # Create a new thread that will be reused - thread = agent.get_new_thread() + # Create a new session that will be reused + session = agent.create_session() # First conversation first_query = "What's the weather like in Tokyo?" print(f"User: {first_query}") - first_result = await agent.run(first_query, thread=thread, options={"store": False}) + first_result = await agent.run(first_query, session=session, options={"store": False}) print(f"Agent: {first_result.text}") - # Second conversation using the same thread - maintains context + # Second conversation using the same session - maintains context second_query = "How about London?" print(f"\nUser: {second_query}") - second_result = await agent.run(second_query, thread=thread, options={"store": False}) + second_result = await agent.run(second_query, session=session, options={"store": False}) print(f"Agent: {second_result.text}") # Third conversation - agent should remember both previous cities third_query = "Which of the cities I asked about has better weather?" print(f"\nUser: {third_query}") - third_result = await agent.run(third_query, thread=thread, options={"store": False}) + third_result = await agent.run(third_query, session=session, options={"store": False}) print(f"Agent: {third_result.text}") - print("Note: The agent remembers context from previous messages in the same thread.\n") + print("Note: The agent remembers context from previous messages in the same session.\n") -async def example_with_existing_thread_id() -> None: +async def example_with_existing_session_id() -> None: """ - Example showing how to work with an existing thread ID from the service. + Example showing how to work with an existing session ID from the service. 
In this example, messages are stored on the server. """ - print("=== Existing Thread ID Example ===") + print("=== Existing Session ID Example ===") - # First, create a conversation and capture the thread ID - existing_thread_id = None + # First, create a conversation and capture the session ID + existing_session_id = None async with ( AzureCliCredential() as credential, @@ -118,20 +118,20 @@ async def example_with_existing_thread_id() -> None: tools=get_weather, ) - # Start a conversation and get the thread ID - thread = agent.get_new_thread() + # Start a conversation and get the session ID + session = agent.create_session() first_query = "What's the weather in Paris?" print(f"User: {first_query}") - first_result = await agent.run(first_query, thread=thread) + first_result = await agent.run(first_query, session=session) print(f"Agent: {first_result.text}") - # The thread ID is set after the first response - existing_thread_id = thread.service_thread_id - print(f"Thread ID: {existing_thread_id}") + # The session ID is set after the first response + existing_session_id = session.service_session_id + print(f"Session ID: {existing_session_id}") - if existing_thread_id: - print("\n--- Continuing with the same thread ID in a new agent instance ---") + if existing_session_id: + print("\n--- Continuing with the same session ID in a new agent instance ---") # Create a new agent instance from the same provider second_agent = await provider.create_agent( @@ -140,22 +140,22 @@ async def example_with_existing_thread_id() -> None: tools=get_weather, ) - # Create a thread with the existing ID - thread = second_agent.get_new_thread(service_thread_id=existing_thread_id) + # Create a session with the existing ID + session = second_agent.create_session(service_session_id=existing_session_id) second_query = "What was the last city I asked about?" 
print(f"User: {second_query}") - second_result = await second_agent.run(second_query, thread=thread) + second_result = await second_agent.run(second_query, session=session) print(f"Agent: {second_result.text}") - print("Note: The agent continues the conversation from the previous thread by using thread ID.\n") + print("Note: The agent continues the conversation from the previous session by using session ID.\n") async def main() -> None: - print("=== Azure AI Agent Thread Management Examples ===\n") + print("=== Azure AI Agent Session Management Examples ===\n") - await example_with_automatic_thread_creation() - await example_with_thread_persistence_in_memory() - await example_with_existing_thread_id() + await example_with_automatic_session_creation() + await example_with_session_persistence_in_memory() + await example_with_existing_session_id() if __name__ == "__main__": diff --git a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_existing_thread.py b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_existing_thread.py index 9d6148ad59..64b736074a 100644 --- a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_existing_thread.py +++ b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_existing_thread.py @@ -12,10 +12,10 @@ from pydantic import Field """ -Azure AI Agent with Existing Thread Example +Azure AI Agent with Existing Session Example -This sample demonstrates working with pre-existing conversation threads -by providing thread IDs for thread reuse patterns. +This sample demonstrates working with pre-existing conversation sessions +by providing session IDs for session reuse patterns. 
""" @@ -32,7 +32,7 @@ def get_weather( async def main() -> None: - print("=== Azure AI Agent with Existing Thread ===") + print("=== Azure AI Agent with Existing Session ===") # Create the client and provider async with ( @@ -40,7 +40,7 @@ async def main() -> None: AgentsClient(endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=credential) as agents_client, AzureAIAgentsProvider(agents_client=agents_client) as provider, ): - # Create a thread that will persist + # Create a session that will persist created_thread = await agents_client.threads.create() try: @@ -51,12 +51,12 @@ async def main() -> None: tools=get_weather, ) - thread = agent.get_new_thread(service_thread_id=created_thread.id) - assert thread.is_initialized - result = await agent.run("What's the weather like in Tokyo?", thread=thread) + session = agent.create_session(service_session_id=created_thread.id) + assert session.is_initialized + result = await agent.run("What's the weather like in Tokyo?", session=session) print(f"Result: {result}\n") finally: - # Clean up the thread manually + # Clean up the session manually await agents_client.threads.delete(created_thread.id) diff --git a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_hosted_mcp.py b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_hosted_mcp.py index 4a8e234241..9a64bae9a1 100644 --- a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_hosted_mcp.py +++ b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_hosted_mcp.py @@ -3,7 +3,7 @@ import asyncio from typing import Any -from agent_framework import AgentResponse, AgentThread, SupportsAgentRun +from agent_framework import AgentResponse, AgentSession, SupportsAgentRun from agent_framework.azure import AzureAIAgentClient, AzureAIAgentsProvider from azure.identity.aio import AzureCliCredential @@ -15,11 +15,11 @@ """ -async def handle_approvals_with_thread(query: str, agent: "SupportsAgentRun", thread: "AgentThread") -> 
AgentResponse: - """Here we let the thread deal with the previous responses, and we just rerun with the approval.""" +async def handle_approvals_with_session(query: str, agent: "SupportsAgentRun", session: "AgentSession") -> AgentResponse: + """Here we let the session deal with the previous responses, and we just rerun with the approval.""" from agent_framework import Message - result = await agent.run(query, thread=thread, store=True) + result = await agent.run(query, session=session, store=True) while len(result.user_input_requests) > 0: new_input: list[Any] = [] for user_input_needed in result.user_input_requests: @@ -34,7 +34,7 @@ async def handle_approvals_with_thread(query: str, agent: "SupportsAgentRun", th contents=[user_input_needed.to_function_approval_response(user_approval.lower() == "y")], ) ) - result = await agent.run(new_input, thread=thread, store=True) + result = await agent.run(new_input, session=session, store=True) return result @@ -58,17 +58,17 @@ async def main() -> None: instructions="You are a helpful assistant that can help with microsoft documentation questions.", tools=[mcp_tool], ) - thread = agent.get_new_thread() + session = agent.create_session() # First query query1 = "How to create an Azure storage account using az cli?" print(f"User: {query1}") - result1 = await handle_approvals_with_thread(query1, agent, thread) + result1 = await handle_approvals_with_session(query1, agent, session) print(f"{agent.name}: {result1}\n") print("\n=======================================\n") # Second query query2 = "What is Microsoft Agent Framework?" 
print(f"User: {query2}") - result2 = await handle_approvals_with_thread(query2, agent, thread) + result2 = await handle_approvals_with_session(query2, agent, session) print(f"{agent.name}: {result2}\n") diff --git a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_multiple_tools.py b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_multiple_tools.py index 35e6650748..4b3cfeadb7 100644 --- a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_multiple_tools.py +++ b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_multiple_tools.py @@ -5,7 +5,7 @@ from typing import Any from agent_framework import ( - AgentThread, + AgentSession, SupportsAgentRun, tool, ) @@ -43,11 +43,11 @@ def get_time() -> str: return f"The current UTC time is {current_time.strftime('%Y-%m-%d %H:%M:%S')}." -async def handle_approvals_with_thread(query: str, agent: "SupportsAgentRun", thread: "AgentThread"): - """Here we let the thread deal with the previous responses, and we just rerun with the approval.""" +async def handle_approvals_with_session(query: str, agent: "SupportsAgentRun", session: "AgentSession"): + """Here we let the session deal with the previous responses, and we just rerun with the approval.""" from agent_framework import Message - result = await agent.run(query, thread=thread, store=True) + result = await agent.run(query, session=session, store=True) while len(result.user_input_requests) > 0: new_input: list[Any] = [] for user_input_needed in result.user_input_requests: @@ -62,7 +62,7 @@ async def handle_approvals_with_thread(query: str, agent: "SupportsAgentRun", th contents=[user_input_needed.to_function_approval_response(user_approval.lower() == "y")], ) ) - result = await agent.run(new_input, thread=thread, store=True) + result = await agent.run(new_input, session=session, store=True) return result @@ -91,17 +91,17 @@ async def main() -> None: get_time, ], ) - thread = agent.get_new_thread() + session = 
agent.create_session() # First query query1 = "How to create an Azure storage account using az cli and what time is it?" print(f"User: {query1}") - result1 = await handle_approvals_with_thread(query1, agent, thread) + result1 = await handle_approvals_with_session(query1, agent, session) print(f"{agent.name}: {result1}\n") print("\n=======================================\n") # Second query query2 = "What is Microsoft Agent Framework and use a web search to see what is Reddit saying about it?" print(f"User: {query2}") - result2 = await handle_approvals_with_thread(query2, agent, thread) + result2 = await handle_approvals_with_session(query2, agent, session) print(f"{agent.name}: {result2}\n") diff --git a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_openapi_tools.py b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_openapi_tools.py index 125b982834..ff5ad8c8dc 100644 --- a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_openapi_tools.py +++ b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_openapi_tools.py @@ -76,16 +76,16 @@ async def main() -> None: tools=[*openapi_countries.definitions, *openapi_weather.definitions], ) - # 5. Simulate conversation with the agent maintaining thread context + # 5. 
Simulate conversation with the agent maintaining session context print("=== Azure AI Agent with OpenAPI Tools ===\n") - # Create a thread to maintain conversation context across multiple runs - thread = agent.get_new_thread() + # Create a session to maintain conversation context across multiple runs + session = agent.create_session() for user_input in USER_INPUTS: print(f"User: {user_input}") - # Pass the thread to maintain context across multiple agent.run() calls - response = await agent.run(user_input, thread=thread) + # Pass the session to maintain context across multiple agent.run() calls + response = await agent.run(user_input, session=session) print(f"Agent: {response.text}\n") diff --git a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_thread.py b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_thread.py index e5957905ae..190c002747 100644 --- a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_thread.py +++ b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_thread.py @@ -4,16 +4,16 @@ from random import randint from typing import Annotated -from agent_framework import AgentThread, tool +from agent_framework import AgentSession, tool from agent_framework.azure import AzureAIAgentsProvider from azure.identity.aio import AzureCliCredential from pydantic import Field """ -Azure AI Agent with Thread Management Example +Azure AI Agent with Session Management Example -This sample demonstrates thread management with Azure AI Agents, comparing -automatic thread creation with explicit thread management for persistent context. +This sample demonstrates session management with Azure AI Agents, comparing +automatic session creation with explicit session management for persistent context. """ @@ -29,9 +29,9 @@ def get_weather( return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." 
-async def example_with_automatic_thread_creation() -> None: - """Example showing automatic thread creation (service-managed thread).""" - print("=== Automatic Thread Creation Example ===") +async def example_with_automatic_session_creation() -> None: + """Example showing automatic session creation (service-managed session).""" + print("=== Automatic Session Creation Example ===") # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. @@ -45,24 +45,24 @@ async def example_with_automatic_thread_creation() -> None: tools=get_weather, ) - # First conversation - no thread provided, will be created automatically + # First conversation - no session provided, will be created automatically first_query = "What's the weather like in Seattle?" print(f"User: {first_query}") first_result = await agent.run(first_query) print(f"Agent: {first_result.text}") - # Second conversation - still no thread provided, will create another new thread + # Second conversation - still no session provided, will create another new session second_query = "What was the last city I asked about?" 
print(f"\nUser: {second_query}") second_result = await agent.run(second_query) print(f"Agent: {second_result.text}") - print("Note: Each call creates a separate thread, so the agent doesn't remember previous context.\n") + print("Note: Each call creates a separate session, so the agent doesn't remember previous context.\n") -async def example_with_thread_persistence() -> None: - """Example showing thread persistence across multiple conversations.""" - print("=== Thread Persistence Example ===") - print("Using the same thread across multiple conversations to maintain context.\n") +async def example_with_session_persistence() -> None: + """Example showing session persistence across multiple conversations.""" + print("=== Session Persistence Example ===") + print("Using the same session across multiple conversations to maintain context.\n") # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. @@ -76,36 +76,36 @@ async def example_with_thread_persistence() -> None: tools=get_weather, ) - # Create a new thread that will be reused - thread = agent.get_new_thread() + # Create a new session that will be reused + session = agent.create_session() # First conversation first_query = "What's the weather like in Tokyo?" print(f"User: {first_query}") - first_result = await agent.run(first_query, thread=thread) + first_result = await agent.run(first_query, session=session) print(f"Agent: {first_result.text}") - # Second conversation using the same thread - maintains context + # Second conversation using the same session - maintains context second_query = "How about London?" print(f"\nUser: {second_query}") - second_result = await agent.run(second_query, thread=thread) + second_result = await agent.run(second_query, session=session) print(f"Agent: {second_result.text}") # Third conversation - agent should remember both previous cities third_query = "Which of the cities I asked about has better weather?" 
print(f"\nUser: {third_query}") - third_result = await agent.run(third_query, thread=thread) + third_result = await agent.run(third_query, session=session) print(f"Agent: {third_result.text}") - print("Note: The agent remembers context from previous messages in the same thread.\n") + print("Note: The agent remembers context from previous messages in the same session.\n") -async def example_with_existing_thread_id() -> None: - """Example showing how to work with an existing thread ID from the service.""" - print("=== Existing Thread ID Example ===") - print("Using a specific thread ID to continue an existing conversation.\n") +async def example_with_existing_session_id() -> None: + """Example showing how to work with an existing session ID from the service.""" + print("=== Existing Session ID Example ===") + print("Using a specific session ID to continue an existing conversation.\n") - # First, create a conversation and capture the thread ID - existing_thread_id = None + # First, create a conversation and capture the session ID + existing_session_id = None # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. @@ -119,21 +119,21 @@ async def example_with_existing_thread_id() -> None: tools=get_weather, ) - # Start a conversation and get the thread ID - thread = agent.get_new_thread() + # Start a conversation and get the session ID + session = agent.create_session() first_query = "What's the weather in Paris?" 
print(f"User: {first_query}") - first_result = await agent.run(first_query, thread=thread) + first_result = await agent.run(first_query, session=session) print(f"Agent: {first_result.text}") - # The thread ID is set after the first response - existing_thread_id = thread.service_thread_id - print(f"Thread ID: {existing_thread_id}") + # The session ID is set after the first response + existing_session_id = session.service_session_id + print(f"Session ID: {existing_session_id}") - if existing_thread_id: - print("\n--- Continuing with the same thread ID in a new agent instance ---") + if existing_session_id: + print("\n--- Continuing with the same session ID in a new agent instance ---") - # Create a new provider and agent but use the existing thread ID + # Create a new provider and agent but use the existing session ID async with ( AzureCliCredential() as credential, AzureAIAgentsProvider(credential=credential) as provider, @@ -144,22 +144,22 @@ async def example_with_existing_thread_id() -> None: tools=get_weather, ) - # Create a thread with the existing ID - thread = AgentThread(service_thread_id=existing_thread_id) + # Create a session with the existing ID + session = AgentSession(service_session_id=existing_session_id) second_query = "What was the last city I asked about?" 
print(f"User: {second_query}") - second_result = await agent.run(second_query, thread=thread) + second_result = await agent.run(second_query, session=session) print(f"Agent: {second_result.text}") - print("Note: The agent continues the conversation from the previous thread.\n") + print("Note: The agent continues the conversation from the previous session.\n") async def main() -> None: - print("=== Azure AI Chat Client Agent Thread Management Examples ===\n") + print("=== Azure AI Chat Client Agent Session Management Examples ===\n") - await example_with_automatic_thread_creation() - await example_with_thread_persistence() - await example_with_existing_thread_id() + await example_with_automatic_session_creation() + await example_with_session_persistence() + await example_with_existing_session_id() if __name__ == "__main__": diff --git a/python/samples/02-agents/providers/azure_openai/azure_assistants_with_thread.py b/python/samples/02-agents/providers/azure_openai/azure_assistants_with_thread.py index dce9699a96..edc9b2edbb 100644 --- a/python/samples/02-agents/providers/azure_openai/azure_assistants_with_thread.py +++ b/python/samples/02-agents/providers/azure_openai/azure_assistants_with_thread.py @@ -4,16 +4,16 @@ from random import randint from typing import Annotated -from agent_framework import Agent, AgentThread, tool +from agent_framework import Agent, AgentSession, tool from agent_framework.azure import AzureOpenAIAssistantsClient from azure.identity import AzureCliCredential from pydantic import Field """ -Azure OpenAI Assistants with Thread Management Example +Azure OpenAI Assistants with Session Management Example -This sample demonstrates thread management with Azure OpenAI Assistants, comparing -automatic thread creation with explicit thread management for persistent context. +This sample demonstrates session management with Azure OpenAI Assistants, comparing +automatic session creation with explicit session management for persistent context. 
""" @@ -29,9 +29,9 @@ def get_weather( return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." -async def example_with_automatic_thread_creation() -> None: - """Example showing automatic thread creation (service-managed thread).""" - print("=== Automatic Thread Creation Example ===") +async def example_with_automatic_session_creation() -> None: + """Example showing automatic session creation (service-managed session).""" + print("=== Automatic Session Creation Example ===") # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. @@ -40,24 +40,24 @@ async def example_with_automatic_thread_creation() -> None: instructions="You are a helpful weather agent.", tools=get_weather, ) as agent: - # First conversation - no thread provided, will be created automatically + # First conversation - no session provided, will be created automatically query1 = "What's the weather like in Seattle?" print(f"User: {query1}") result1 = await agent.run(query1) print(f"Agent: {result1.text}") - # Second conversation - still no thread provided, will create another new thread + # Second conversation - still no session provided, will create another new session query2 = "What was the last city I asked about?" 
print(f"\nUser: {query2}") result2 = await agent.run(query2) print(f"Agent: {result2.text}") - print("Note: Each call creates a separate thread, so the agent doesn't remember previous context.\n") + print("Note: Each call creates a separate session, so the agent doesn't remember previous context.\n") -async def example_with_thread_persistence() -> None: - """Example showing thread persistence across multiple conversations.""" - print("=== Thread Persistence Example ===") - print("Using the same thread across multiple conversations to maintain context.\n") +async def example_with_session_persistence() -> None: + """Example showing session persistence across multiple conversations.""" + print("=== Session Persistence Example ===") + print("Using the same session across multiple conversations to maintain context.\n") # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. @@ -66,36 +66,36 @@ async def example_with_thread_persistence() -> None: instructions="You are a helpful weather agent.", tools=get_weather, ) as agent: - # Create a new thread that will be reused - thread = agent.get_new_thread() + # Create a new session that will be reused + session = agent.create_session() # First conversation query1 = "What's the weather like in Tokyo?" print(f"User: {query1}") - result1 = await agent.run(query1, thread=thread) + result1 = await agent.run(query1, session=session) print(f"Agent: {result1.text}") - # Second conversation using the same thread - maintains context + # Second conversation using the same session - maintains context query2 = "How about London?" print(f"\nUser: {query2}") - result2 = await agent.run(query2, thread=thread) + result2 = await agent.run(query2, session=session) print(f"Agent: {result2.text}") # Third conversation - agent should remember both previous cities query3 = "Which of the cities I asked about has better weather?" 
print(f"\nUser: {query3}") - result3 = await agent.run(query3, thread=thread) + result3 = await agent.run(query3, session=session) print(f"Agent: {result3.text}") - print("Note: The agent remembers context from previous messages in the same thread.\n") + print("Note: The agent remembers context from previous messages in the same session.\n") -async def example_with_existing_thread_id() -> None: - """Example showing how to work with an existing thread ID from the service.""" - print("=== Existing Thread ID Example ===") - print("Using a specific thread ID to continue an existing conversation.\n") +async def example_with_existing_session_id() -> None: + """Example showing how to work with an existing session ID from the service.""" + print("=== Existing Session ID Example ===") + print("Using a specific session ID to continue an existing conversation.\n") - # First, create a conversation and capture the thread ID - existing_thread_id = None + # First, create a conversation and capture the session ID + existing_session_id = None # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. @@ -104,42 +104,42 @@ async def example_with_existing_thread_id() -> None: instructions="You are a helpful weather agent.", tools=get_weather, ) as agent: - # Start a conversation and get the thread ID - thread = agent.get_new_thread() + # Start a conversation and get the session ID + session = agent.create_session() query1 = "What's the weather in Paris?" 
print(f"User: {query1}") - result1 = await agent.run(query1, thread=thread) + result1 = await agent.run(query1, session=session) print(f"Agent: {result1.text}") - # The thread ID is set after the first response - existing_thread_id = thread.service_thread_id - print(f"Thread ID: {existing_thread_id}") + # The session ID is set after the first response + existing_session_id = session.service_session_id + print(f"Session ID: {existing_session_id}") - if existing_thread_id: - print("\n--- Continuing with the same thread ID in a new agent instance ---") + if existing_session_id: + print("\n--- Continuing with the same session ID in a new agent instance ---") - # Create a new agent instance but use the existing thread ID + # Create a new agent instance but use the existing session ID async with Agent( - client=AzureOpenAIAssistantsClient(thread_id=existing_thread_id, credential=AzureCliCredential()), + client=AzureOpenAIAssistantsClient(thread_id=existing_session_id, credential=AzureCliCredential()), instructions="You are a helpful weather agent.", tools=get_weather, ) as agent: - # Create a thread with the existing ID - thread = AgentThread(service_thread_id=existing_thread_id) + # Create a session with the existing ID + session = AgentSession(service_session_id=existing_session_id) query2 = "What was the last city I asked about?" 
print(f"User: {query2}") - result2 = await agent.run(query2, thread=thread) + result2 = await agent.run(query2, session=session) print(f"Agent: {result2.text}") - print("Note: The agent continues the conversation from the previous thread.\n") + print("Note: The agent continues the conversation from the previous session.\n") async def main() -> None: - print("=== Azure OpenAI Assistants Chat Client Agent Thread Management Examples ===\n") + print("=== Azure OpenAI Assistants Chat Client Agent Session Management Examples ===\n") - await example_with_automatic_thread_creation() - await example_with_thread_persistence() - await example_with_existing_thread_id() + await example_with_automatic_session_creation() + await example_with_session_persistence() + await example_with_existing_session_id() if __name__ == "__main__": diff --git a/python/samples/02-agents/providers/azure_openai/azure_chat_client_with_thread.py b/python/samples/02-agents/providers/azure_openai/azure_chat_client_with_thread.py index bdad0faa3f..1382a14843 100644 --- a/python/samples/02-agents/providers/azure_openai/azure_chat_client_with_thread.py +++ b/python/samples/02-agents/providers/azure_openai/azure_chat_client_with_thread.py @@ -4,16 +4,16 @@ from random import randint from typing import Annotated -from agent_framework import Agent, AgentThread, ChatMessageStore, tool +from agent_framework import Agent, AgentSession, tool from agent_framework.azure import AzureOpenAIChatClient from azure.identity import AzureCliCredential from pydantic import Field """ -Azure OpenAI Chat Client with Thread Management Example +Azure OpenAI Chat Client with Session Management Example -This sample demonstrates thread management with Azure OpenAI Chat Client, comparing -automatic thread creation with explicit thread management for persistent context. 
+This sample demonstrates session management with Azure OpenAI Chat Client, comparing +automatic session creation with explicit session management for persistent context. """ @@ -29,9 +29,9 @@ def get_weather( return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." -async def example_with_automatic_thread_creation() -> None: - """Example showing automatic thread creation (service-managed thread).""" - print("=== Automatic Thread Creation Example ===") +async def example_with_automatic_session_creation() -> None: + """Example showing automatic session creation (service-managed session).""" + print("=== Automatic Session Creation Example ===") # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. @@ -41,24 +41,24 @@ async def example_with_automatic_thread_creation() -> None: tools=get_weather, ) - # First conversation - no thread provided, will be created automatically + # First conversation - no session provided, will be created automatically query1 = "What's the weather like in Seattle?" print(f"User: {query1}") result1 = await agent.run(query1) print(f"Agent: {result1.text}") - # Second conversation - still no thread provided, will create another new thread + # Second conversation - still no session provided, will create another new session query2 = "What was the last city I asked about?" 
print(f"\nUser: {query2}") result2 = await agent.run(query2) print(f"Agent: {result2.text}") - print("Note: Each call creates a separate thread, so the agent doesn't remember previous context.\n") + print("Note: Each call creates a separate session, so the agent doesn't remember previous context.\n") -async def example_with_thread_persistence() -> None: - """Example showing thread persistence across multiple conversations.""" - print("=== Thread Persistence Example ===") - print("Using the same thread across multiple conversations to maintain context.\n") +async def example_with_session_persistence() -> None: + """Example showing session persistence across multiple conversations.""" + print("=== Session Persistence Example ===") + print("Using the same session across multiple conversations to maintain context.\n") # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. @@ -68,32 +68,32 @@ async def example_with_thread_persistence() -> None: tools=get_weather, ) - # Create a new thread that will be reused - thread = agent.get_new_thread() + # Create a new session that will be reused + session = agent.create_session() # First conversation query1 = "What's the weather like in Tokyo?" print(f"User: {query1}") - result1 = await agent.run(query1, thread=thread) + result1 = await agent.run(query1, session=session) print(f"Agent: {result1.text}") - # Second conversation using the same thread - maintains context + # Second conversation using the same session - maintains context query2 = "How about London?" print(f"\nUser: {query2}") - result2 = await agent.run(query2, thread=thread) + result2 = await agent.run(query2, session=session) print(f"Agent: {result2.text}") # Third conversation - agent should remember both previous cities query3 = "Which of the cities I asked about has better weather?" 
print(f"\nUser: {query3}") - result3 = await agent.run(query3, thread=thread) + result3 = await agent.run(query3, session=session) print(f"Agent: {result3.text}") - print("Note: The agent remembers context from previous messages in the same thread.\n") + print("Note: The agent remembers context from previous messages in the same session.\n") -async def example_with_existing_thread_messages() -> None: - """Example showing how to work with existing thread messages for Azure.""" - print("=== Existing Thread Messages Example ===") +async def example_with_existing_session_messages() -> None: + """Example showing how to work with existing session messages for Azure.""" + print("=== Existing Session Messages Example ===") # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. @@ -104,53 +104,53 @@ async def example_with_existing_thread_messages() -> None: ) # Start a conversation and build up message history - thread = agent.get_new_thread() + session = agent.create_session() query1 = "What's the weather in Paris?" 
print(f"User: {query1}") - result1 = await agent.run(query1, thread=thread) + result1 = await agent.run(query1, session=session) print(f"Agent: {result1.text}") - # The thread now contains the conversation history in memory - if thread.message_store: - messages = await thread.message_store.list_messages() - print(f"Thread contains {len(messages or [])} messages") + # The session now contains the conversation history in memory + if session.message_store: + messages = await session.message_store.list_messages() + print(f"Session contains {len(messages or [])} messages") - print("\n--- Continuing with the same thread in a new agent instance ---") + print("\n--- Continuing with the same session in a new agent instance ---") - # Create a new agent instance but use the existing thread with its message history + # Create a new agent instance but use the existing session with its message history new_agent = Agent( client=AzureOpenAIChatClient(credential=AzureCliCredential()), instructions="You are a helpful weather agent.", tools=get_weather, ) - # Use the same thread object which contains the conversation history + # Use the same session object which contains the conversation history query2 = "What was the last city I asked about?" 
print(f"User: {query2}") - result2 = await new_agent.run(query2, thread=thread) + result2 = await new_agent.run(query2, session=session) print(f"Agent: {result2.text}") print("Note: The agent continues the conversation using the local message history.\n") - print("\n--- Alternative: Creating a new thread from existing messages ---") + print("\n--- Alternative: Creating a new session from existing messages ---") - # You can also create a new thread from existing messages - messages = await thread.message_store.list_messages() if thread.message_store else [] - new_thread = AgentThread(message_store=ChatMessageStore(messages)) + # You can also create a new session from existing messages + messages = await session.message_store.list_messages() if session.message_store else [] + new_session = AgentSession() query3 = "How does the Paris weather compare to London?" print(f"User: {query3}") - result3 = await new_agent.run(query3, thread=new_thread) + result3 = await new_agent.run(query3, session=new_session) print(f"Agent: {result3.text}") - print("Note: This creates a new thread with the same conversation history.\n") + print("Note: This creates a new session with the same conversation history.\n") async def main() -> None: - print("=== Azure Chat Client Agent Thread Management Examples ===\n") + print("=== Azure Chat Client Agent Session Management Examples ===\n") - await example_with_automatic_thread_creation() - await example_with_thread_persistence() - await example_with_existing_thread_messages() + await example_with_automatic_session_creation() + await example_with_session_persistence() + await example_with_existing_session_messages() if __name__ == "__main__": diff --git a/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_hosted_mcp.py b/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_hosted_mcp.py index bcc6f636b5..9de272c62a 100644 --- 
a/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_hosted_mcp.py +++ b/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_hosted_mcp.py @@ -15,11 +15,11 @@ """ if TYPE_CHECKING: - from agent_framework import AgentThread, SupportsAgentRun + from agent_framework import AgentSession, SupportsAgentRun -async def handle_approvals_without_thread(query: str, agent: "SupportsAgentRun"): - """When we don't have a thread, we need to ensure we return with the input, approval request and approval.""" +async def handle_approvals_without_session(query: str, agent: "SupportsAgentRun"): + """When we don't have a session, we need to ensure we return with the input, approval request and approval.""" from agent_framework import Message result = await agent.run(query) @@ -43,11 +43,11 @@ async def handle_approvals_without_thread(query: str, agent: "SupportsAgentRun") return result -async def handle_approvals_with_thread(query: str, agent: "SupportsAgentRun", thread: "AgentThread"): - """Here we let the thread deal with the previous responses, and we just rerun with the approval.""" +async def handle_approvals_with_session(query: str, agent: "SupportsAgentRun", session: "AgentSession"): + """Here we let the session deal with the previous responses, and we just rerun with the approval.""" from agent_framework import Message - result = await agent.run(query, thread=thread, store=True) + result = await agent.run(query, session=session, store=True) while len(result.user_input_requests) > 0: new_input: list[Any] = [] for user_input_needed in result.user_input_requests: @@ -62,12 +62,12 @@ async def handle_approvals_with_thread(query: str, agent: "SupportsAgentRun", th contents=[user_input_needed.to_function_approval_response(user_approval.lower() == "y")], ) ) - result = await agent.run(new_input, thread=thread, store=True) + result = await agent.run(new_input, session=session, store=True) return result -async def 
handle_approvals_with_thread_streaming(query: str, agent: "SupportsAgentRun", thread: "AgentThread"): - """Here we let the thread deal with the previous responses, and we just rerun with the approval.""" +async def handle_approvals_with_session_streaming(query: str, agent: "SupportsAgentRun", session: "AgentSession"): + """Here we let the session deal with the previous responses, and we just rerun with the approval.""" from agent_framework import Message new_input: list[Message] = [] @@ -75,7 +75,7 @@ async def handle_approvals_with_thread_streaming(query: str, agent: "SupportsAge while new_input_added: new_input_added = False new_input.append(Message(role="user", text=query)) - async for update in agent.run(new_input, thread=thread, options={"store": True}, stream=True): + async for update in agent.run(new_input, session=session, options={"store": True}, stream=True): if update.user_input_requests: for user_input_needed in update.user_input_requests: print( @@ -94,9 +94,9 @@ async def handle_approvals_with_thread_streaming(query: str, agent: "SupportsAge yield update -async def run_hosted_mcp_without_thread_and_specific_approval() -> None: - """Example showing Mcp Tools with approvals without using a thread.""" - print("=== Mcp with approvals and without thread ===") +async def run_hosted_mcp_without_session_and_specific_approval() -> None: + """Example showing Mcp Tools with approvals without using a session.""" + print("=== Mcp with approvals and without session ===") credential = AzureCliCredential() client = AzureOpenAIResponsesClient(credential=credential) @@ -120,13 +120,13 @@ async def run_hosted_mcp_without_thread_and_specific_approval() -> None: # First query query1 = "How to create an Azure storage account using az cli?" 
print(f"User: {query1}") - result1 = await handle_approvals_without_thread(query1, agent) + result1 = await handle_approvals_without_session(query1, agent) print(f"{agent.name}: {result1}\n") print("\n=======================================\n") # Second query query2 = "What is Microsoft Agent Framework?" print(f"User: {query2}") - result2 = await handle_approvals_without_thread(query2, agent) + result2 = await handle_approvals_without_session(query2, agent) print(f"{agent.name}: {result2}\n") @@ -157,19 +157,19 @@ async def run_hosted_mcp_without_approval() -> None: # First query query1 = "How to create an Azure storage account using az cli?" print(f"User: {query1}") - result1 = await handle_approvals_without_thread(query1, agent) + result1 = await handle_approvals_without_session(query1, agent) print(f"{agent.name}: {result1}\n") print("\n=======================================\n") # Second query query2 = "What is Microsoft Agent Framework?" print(f"User: {query2}") - result2 = await handle_approvals_without_thread(query2, agent) + result2 = await handle_approvals_without_session(query2, agent) print(f"{agent.name}: {result2}\n") -async def run_hosted_mcp_with_thread() -> None: - """Example showing Mcp Tools with approvals using a thread.""" - print("=== Mcp with approvals and with thread ===") +async def run_hosted_mcp_with_session() -> None: + """Example showing Mcp Tools with approvals using a session.""" + print("=== Mcp with approvals and with session ===") credential = AzureCliCredential() client = AzureOpenAIResponsesClient(credential=credential) @@ -190,22 +190,22 @@ async def run_hosted_mcp_with_thread() -> None: tools=[mcp_tool], ) as agent: # First query - thread = agent.get_new_thread() + session = agent.create_session() query1 = "How to create an Azure storage account using az cli?" 
print(f"User: {query1}") - result1 = await handle_approvals_with_thread(query1, agent, thread) + result1 = await handle_approvals_with_session(query1, agent, session) print(f"{agent.name}: {result1}\n") print("\n=======================================\n") # Second query query2 = "What is Microsoft Agent Framework?" print(f"User: {query2}") - result2 = await handle_approvals_with_thread(query2, agent, thread) + result2 = await handle_approvals_with_session(query2, agent, session) print(f"{agent.name}: {result2}\n") -async def run_hosted_mcp_with_thread_streaming() -> None: - """Example showing Mcp Tools with approvals using a thread.""" - print("=== Mcp with approvals and with thread ===") +async def run_hosted_mcp_with_session_streaming() -> None: + """Example showing Mcp Tools with approvals using a session.""" + print("=== Mcp with approvals and with session ===") credential = AzureCliCredential() client = AzureOpenAIResponsesClient(credential=credential) @@ -226,11 +226,11 @@ async def run_hosted_mcp_with_thread_streaming() -> None: tools=[mcp_tool], ) as agent: # First query - thread = agent.get_new_thread() + session = agent.create_session() query1 = "How to create an Azure storage account using az cli?" print(f"User: {query1}") print(f"{agent.name}: ", end="") - async for update in handle_approvals_with_thread_streaming(query1, agent, thread): + async for update in handle_approvals_with_session_streaming(query1, agent, session): print(update, end="") print("\n") print("\n=======================================\n") @@ -238,7 +238,7 @@ async def run_hosted_mcp_with_thread_streaming() -> None: query2 = "What is Microsoft Agent Framework?" 
print(f"User: {query2}") print(f"{agent.name}: ", end="") - async for update in handle_approvals_with_thread_streaming(query2, agent, thread): + async for update in handle_approvals_with_session_streaming(query2, agent, session): print(update, end="") print("\n") @@ -247,9 +247,9 @@ async def main() -> None: print("=== OpenAI Responses Client Agent with Hosted Mcp Tools Examples ===\n") await run_hosted_mcp_without_approval() - await run_hosted_mcp_without_thread_and_specific_approval() - await run_hosted_mcp_with_thread() - await run_hosted_mcp_with_thread_streaming() + await run_hosted_mcp_without_session_and_specific_approval() + await run_hosted_mcp_with_session() + await run_hosted_mcp_with_session_streaming() if __name__ == "__main__": diff --git a/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_thread.py b/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_thread.py index f27af60e88..2de40871a4 100644 --- a/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_thread.py +++ b/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_thread.py @@ -4,16 +4,16 @@ from random import randint from typing import Annotated -from agent_framework import Agent, AgentThread, tool +from agent_framework import Agent, AgentSession, tool from agent_framework.azure import AzureOpenAIResponsesClient from azure.identity import AzureCliCredential from pydantic import Field """ -Azure OpenAI Responses Client with Thread Management Example +Azure OpenAI Responses Client with Session Management Example -This sample demonstrates thread management with Azure OpenAI Responses Client, comparing -automatic thread creation with explicit thread management for persistent context. +This sample demonstrates session management with Azure OpenAI Responses Client, comparing +automatic session creation with explicit session management for persistent context. 
""" @@ -29,9 +29,9 @@ def get_weather( return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." -async def example_with_automatic_thread_creation() -> None: - """Example showing automatic thread creation.""" - print("=== Automatic Thread Creation Example ===") +async def example_with_automatic_session_creation() -> None: + """Example showing automatic session creation.""" + print("=== Automatic Session Creation Example ===") # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. @@ -41,26 +41,26 @@ async def example_with_automatic_thread_creation() -> None: tools=get_weather, ) - # First conversation - no thread provided, will be created automatically + # First conversation - no session provided, will be created automatically query1 = "What's the weather like in Seattle?" print(f"User: {query1}") result1 = await agent.run(query1) print(f"Agent: {result1.text}") - # Second conversation - still no thread provided, will create another new thread + # Second conversation - still no session provided, will create another new session query2 = "What was the last city I asked about?" print(f"\nUser: {query2}") result2 = await agent.run(query2) print(f"Agent: {result2.text}") - print("Note: Each call creates a separate thread, so the agent doesn't remember previous context.\n") + print("Note: Each call creates a separate session, so the agent doesn't remember previous context.\n") -async def example_with_thread_persistence_in_memory() -> None: +async def example_with_session_persistence_in_memory() -> None: """ - Example showing thread persistence across multiple conversations. + Example showing session persistence across multiple conversations. In this example, messages are stored in-memory. 
""" - print("=== Thread Persistence Example (In-Memory) ===") + print("=== Session Persistence Example (In-Memory) ===") # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. @@ -70,38 +70,38 @@ async def example_with_thread_persistence_in_memory() -> None: tools=get_weather, ) - # Create a new thread that will be reused - thread = agent.get_new_thread() + # Create a new session that will be reused + session = agent.create_session() # First conversation query1 = "What's the weather like in Tokyo?" print(f"User: {query1}") - result1 = await agent.run(query1, thread=thread) + result1 = await agent.run(query1, session=session) print(f"Agent: {result1.text}") - # Second conversation using the same thread - maintains context + # Second conversation using the same session - maintains context query2 = "How about London?" print(f"\nUser: {query2}") - result2 = await agent.run(query2, thread=thread) + result2 = await agent.run(query2, session=session) print(f"Agent: {result2.text}") # Third conversation - agent should remember both previous cities query3 = "Which of the cities I asked about has better weather?" print(f"\nUser: {query3}") - result3 = await agent.run(query3, thread=thread) + result3 = await agent.run(query3, session=session) print(f"Agent: {result3.text}") - print("Note: The agent remembers context from previous messages in the same thread.\n") + print("Note: The agent remembers context from previous messages in the same session.\n") -async def example_with_existing_thread_id() -> None: +async def example_with_existing_session_id() -> None: """ - Example showing how to work with an existing thread ID from the service. + Example showing how to work with an existing session ID from the service. In this example, messages are stored on the server using Azure OpenAI conversation state. 
""" - print("=== Existing Thread ID Example ===") + print("=== Existing Session ID Example ===") - # First, create a conversation and capture the thread ID - existing_thread_id = None + # First, create a conversation and capture the session ID + existing_session_id = None # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. @@ -111,21 +111,21 @@ async def example_with_existing_thread_id() -> None: tools=get_weather, ) - # Start a conversation and get the thread ID - thread = agent.get_new_thread() + # Start a conversation and get the session ID + session = agent.create_session() query1 = "What's the weather in Paris?" print(f"User: {query1}") # Enable Azure OpenAI conversation state by setting `store` parameter to True - result1 = await agent.run(query1, thread=thread, store=True) + result1 = await agent.run(query1, session=session, store=True) print(f"Agent: {result1.text}") - # The thread ID is set after the first response - existing_thread_id = thread.service_thread_id - print(f"Thread ID: {existing_thread_id}") + # The session ID is set after the first response + existing_session_id = session.service_session_id + print(f"Session ID: {existing_session_id}") - if existing_thread_id: - print("\n--- Continuing with the same thread ID in a new agent instance ---") + if existing_session_id: + print("\n--- Continuing with the same session ID in a new agent instance ---") agent = Agent( client=AzureOpenAIResponsesClient(credential=AzureCliCredential()), @@ -133,22 +133,22 @@ async def example_with_existing_thread_id() -> None: tools=get_weather, ) - # Create a thread with the existing ID - thread = AgentThread(service_thread_id=existing_thread_id) + # Create a session with the existing ID + session = AgentSession(service_session_id=existing_session_id) query2 = "What was the last city I asked about?" 
print(f"User: {query2}") - result2 = await agent.run(query2, thread=thread, store=True) + result2 = await agent.run(query2, session=session, store=True) print(f"Agent: {result2.text}") - print("Note: The agent continues the conversation from the previous thread by using thread ID.\n") + print("Note: The agent continues the conversation from the previous session by using session ID.\n") async def main() -> None: - print("=== Azure OpenAI Response Client Agent Thread Management Examples ===\n") + print("=== Azure OpenAI Response Client Agent Session Management Examples ===\n") - await example_with_automatic_thread_creation() - await example_with_thread_persistence_in_memory() - await example_with_existing_thread_id() + await example_with_automatic_session_creation() + await example_with_session_persistence_in_memory() + await example_with_existing_session_id() if __name__ == "__main__": diff --git a/python/samples/02-agents/providers/custom/custom_agent.py b/python/samples/02-agents/providers/custom/custom_agent.py index 51fb2452c8..14626388bd 100644 --- a/python/samples/02-agents/providers/custom/custom_agent.py +++ b/python/samples/02-agents/providers/custom/custom_agent.py @@ -7,7 +7,7 @@ from agent_framework import ( AgentResponse, AgentResponseUpdate, - AgentThread, + AgentSession, BaseAgent, Content, Message, @@ -60,7 +60,7 @@ def run( messages: str | Message | list[str] | list[Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> "AsyncIterable[AgentResponseUpdate] | asyncio.Future[AgentResponse]": """Execute the agent and return a response. @@ -68,7 +68,7 @@ def run( Args: messages: The message(s) to process. stream: If True, return an async iterable of updates. If False, return an awaitable response. - thread: The conversation thread (optional). + session: The conversation session (optional). **kwargs: Additional keyword arguments. 
Returns: @@ -76,14 +76,14 @@ def run( When stream=True: An async iterable of AgentResponseUpdate objects. """ if stream: - return self._run_stream(messages=messages, thread=thread, **kwargs) - return self._run(messages=messages, thread=thread, **kwargs) + return self._run_stream(messages=messages, session=session, **kwargs) + return self._run(messages=messages, session=session, **kwargs) async def _run( self, messages: str | Message | list[str] | list[Message] | None = None, *, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> AgentResponse: """Non-streaming implementation.""" @@ -105,9 +105,9 @@ async def _run( response_message = Message(role=Role.ASSISTANT, contents=[Content.from_text(text=echo_text)]) - # Notify the thread of new messages if provided - if thread is not None: - await self._notify_thread_of_new_messages(thread, normalized_messages, response_message) + # Notify the session of new messages if provided + if session is not None: + await self._notify_thread_of_new_messages(session, normalized_messages, response_message) return AgentResponse(messages=[response_message]) @@ -115,7 +115,7 @@ async def _run_stream( self, messages: str | Message | list[str] | list[Message] | None = None, *, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> AsyncIterable[AgentResponseUpdate]: """Streaming implementation.""" @@ -146,10 +146,10 @@ async def _run_stream( # Small delay to simulate streaming await asyncio.sleep(0.1) - # Notify the thread of the complete response if provided - if thread is not None: + # Notify the session of the complete response if provided + if session is not None: complete_response = Message(role=Role.ASSISTANT, contents=[Content.from_text(text=response_text)]) - await self._notify_thread_of_new_messages(thread, normalized_messages, complete_response) + await self._notify_thread_of_new_messages(session, normalized_messages, complete_response) async def 
main() -> None: @@ -180,26 +180,27 @@ async def main() -> None: print(chunk.text, end="", flush=True) print() - # Example with threads - print("\n--- Using Custom Agent with Thread ---") - thread = echo_agent.get_new_thread() + # Example with sessions + print("\n--- Using Custom Agent with Session ---") + session = echo_agent.create_session() # First message - result1 = await echo_agent.run("First message", thread=thread) + result1 = await echo_agent.run("First message", session=session) print("User: First message") print(f"Agent: {result1.messages[0].text}") # Second message in same thread - result2 = await echo_agent.run("Second message", thread=thread) + result2 = await echo_agent.run("Second message", session=session) print("User: Second message") print(f"Agent: {result2.messages[0].text}") # Check conversation history - if thread.message_store: - messages = await thread.message_store.list_messages() - print(f"\nThread contains {len(messages)} messages in history") + memory_state = session.state.get("memory", {}) + messages = memory_state.get("messages", []) + if messages: + print(f"\nSession contains {len(messages)} messages in history") else: - print("\nThread has no message store configured") + print("\nSession has no messages stored") if __name__ == "__main__": diff --git a/python/samples/02-agents/providers/github_copilot/github_copilot_with_session.py b/python/samples/02-agents/providers/github_copilot/github_copilot_with_session.py index e103cdbde3..c07395ba6a 100644 --- a/python/samples/02-agents/providers/github_copilot/github_copilot_with_session.py +++ b/python/samples/02-agents/providers/github_copilot/github_copilot_with_session.py @@ -61,31 +61,31 @@ async def example_with_session_persistence() -> None: ) async with agent: - # Create a thread to maintain conversation context - thread = agent.get_new_thread() + # Create a session to maintain conversation context + session = agent.create_session() # First query query1 = "What's the weather like in 
Tokyo?" print(f"User: {query1}") - result1 = await agent.run(query1, thread=thread) + result1 = await agent.run(query1, session=session) print(f"Agent: {result1}") # Second query - using same thread maintains context query2 = "How about London?" print(f"\nUser: {query2}") - result2 = await agent.run(query2, thread=thread) + result2 = await agent.run(query2, session=session) print(f"Agent: {result2}") # Third query - agent should remember both previous cities query3 = "Which of the cities I asked about has better weather?" print(f"\nUser: {query3}") - result3 = await agent.run(query3, thread=thread) + result3 = await agent.run(query3, session=session) print(f"Agent: {result3}") print("Note: The agent remembers context from previous messages in the same session.\n") async def example_with_existing_session_id() -> None: - """Resume session in new agent instance using service_thread_id.""" + """Resume session in new agent instance using service_session_id.""" print("=== Existing Session ID Example ===") existing_session_id = None @@ -97,15 +97,15 @@ async def example_with_existing_session_id() -> None: ) async with agent1: - thread = agent1.get_new_thread() + session = agent1.create_session() query1 = "What's the weather in Paris?" print(f"User: {query1}") - result1 = await agent1.run(query1, thread=thread) + result1 = await agent1.run(query1, session=session) print(f"Agent: {result1}") # Capture the session ID for later use - existing_session_id = thread.service_thread_id + existing_session_id = session.service_session_id print(f"Session ID: {existing_session_id}") if existing_session_id: @@ -118,12 +118,12 @@ async def example_with_existing_session_id() -> None: ) async with agent2: - # Create thread with existing session ID - thread = agent2.get_new_thread(service_thread_id=existing_session_id) + # Create session with existing session ID + session = agent2.create_session(service_session_id=existing_session_id) query2 = "What was the last city I asked about?" 
print(f"User: {query2}") - result2 = await agent2.run(query2, thread=thread) + result2 = await agent2.run(query2, session=session) print(f"Agent: {result2}") print("Note: The agent continues the conversation using the session ID.\n") diff --git a/python/samples/02-agents/providers/openai/openai_assistants_with_thread.py b/python/samples/02-agents/providers/openai/openai_assistants_with_thread.py index 2214736dc0..155b5d6a73 100644 --- a/python/samples/02-agents/providers/openai/openai_assistants_with_thread.py +++ b/python/samples/02-agents/providers/openai/openai_assistants_with_thread.py @@ -5,16 +5,16 @@ from random import randint from typing import Annotated -from agent_framework import AgentThread, tool +from agent_framework import AgentSession, tool from agent_framework.openai import OpenAIAssistantProvider from openai import AsyncOpenAI from pydantic import Field """ -OpenAI Assistants with Thread Management Example +OpenAI Assistants with Session Management Example -This sample demonstrates thread management with OpenAI Assistants, showing -persistent conversation threads and context preservation across interactions. +This sample demonstrates session management with OpenAI Assistants, showing +persistent conversation sessions and context preservation across interactions. """ @@ -30,9 +30,9 @@ def get_weather( return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}C." 
-async def example_with_automatic_thread_creation() -> None: - """Example showing automatic thread creation (service-managed thread).""" - print("=== Automatic Thread Creation Example ===") +async def example_with_automatic_session_creation() -> None: + """Example showing automatic session creation (service-managed session).""" + print("=== Automatic Session Creation Example ===") client = AsyncOpenAI() provider = OpenAIAssistantProvider(client) @@ -45,26 +45,26 @@ async def example_with_automatic_thread_creation() -> None: ) try: - # First conversation - no thread provided, will be created automatically + # First conversation - no session provided, will be created automatically query1 = "What's the weather like in Seattle?" print(f"User: {query1}") result1 = await agent.run(query1) print(f"Agent: {result1.text}") - # Second conversation - still no thread provided, will create another new thread + # Second conversation - still no session provided, will create another new session query2 = "What was the last city I asked about?" 
print(f"\nUser: {query2}") result2 = await agent.run(query2) print(f"Agent: {result2.text}") - print("Note: Each call creates a separate thread, so the agent doesn't remember previous context.\n") + print("Note: Each call creates a separate session, so the agent doesn't remember previous context.\n") finally: await client.beta.assistants.delete(agent.id) -async def example_with_thread_persistence() -> None: - """Example showing thread persistence across multiple conversations.""" - print("=== Thread Persistence Example ===") - print("Using the same thread across multiple conversations to maintain context.\n") +async def example_with_session_persistence() -> None: + """Example showing session persistence across multiple conversations.""" + print("=== Session Persistence Example ===") + print("Using the same session across multiple conversations to maintain context.\n") client = AsyncOpenAI() provider = OpenAIAssistantProvider(client) @@ -77,41 +77,41 @@ async def example_with_thread_persistence() -> None: ) try: - # Create a new thread that will be reused - thread = agent.get_new_thread() + # Create a new session that will be reused + session = agent.create_session() # First conversation query1 = "What's the weather like in Tokyo?" print(f"User: {query1}") - result1 = await agent.run(query1, thread=thread) + result1 = await agent.run(query1, session=session) print(f"Agent: {result1.text}") - # Second conversation using the same thread - maintains context + # Second conversation using the same session - maintains context query2 = "How about London?" print(f"\nUser: {query2}") - result2 = await agent.run(query2, thread=thread) + result2 = await agent.run(query2, session=session) print(f"Agent: {result2.text}") # Third conversation - agent should remember both previous cities query3 = "Which of the cities I asked about has better weather?" 
print(f"\nUser: {query3}") - result3 = await agent.run(query3, thread=thread) + result3 = await agent.run(query3, session=session) print(f"Agent: {result3.text}") - print("Note: The agent remembers context from previous messages in the same thread.\n") + print("Note: The agent remembers context from previous messages in the same session.\n") finally: await client.beta.assistants.delete(agent.id) -async def example_with_existing_thread_id() -> None: - """Example showing how to work with an existing thread ID from the service.""" - print("=== Existing Thread ID Example ===") - print("Using a specific thread ID to continue an existing conversation.\n") +async def example_with_existing_session_id() -> None: + """Example showing how to work with an existing session ID from the service.""" + print("=== Existing Session ID Example ===") + print("Using a specific session ID to continue an existing conversation.\n") client = AsyncOpenAI() provider = OpenAIAssistantProvider(client) - # First, create a conversation and capture the thread ID - existing_thread_id = None + # First, create a conversation and capture the session ID + existing_session_id = None assistant_id = None agent = await provider.create_agent( @@ -123,19 +123,19 @@ async def example_with_existing_thread_id() -> None: assistant_id = agent.id try: - # Start a conversation and get the thread ID - thread = agent.get_new_thread() + # Start a conversation and get the session ID + session = agent.create_session() query1 = "What's the weather in Paris?" 
print(f"User: {query1}") - result1 = await agent.run(query1, thread=thread) + result1 = await agent.run(query1, session=session) print(f"Agent: {result1.text}") - # The thread ID is set after the first response - existing_thread_id = thread.service_thread_id - print(f"Thread ID: {existing_thread_id}") + # The session ID is set after the first response + existing_session_id = session.service_session_id + print(f"Session ID: {existing_session_id}") - if existing_thread_id: - print("\n--- Continuing with the same thread ID using get_agent ---") + if existing_session_id: + print("\n--- Continuing with the same session ID using get_agent ---") # Get the existing assistant by ID agent2 = await provider.get_agent( @@ -143,25 +143,25 @@ async def example_with_existing_thread_id() -> None: tools=[get_weather], # Must provide function implementations ) - # Create a thread with the existing ID - thread = AgentThread(service_thread_id=existing_thread_id) + # Create a session with the existing ID + session = AgentSession(service_session_id=existing_session_id) query2 = "What was the last city I asked about?" 
print(f"User: {query2}") - result2 = await agent2.run(query2, thread=thread) + result2 = await agent2.run(query2, session=session) print(f"Agent: {result2.text}") - print("Note: The agent continues the conversation from the previous thread.\n") + print("Note: The agent continues the conversation from the previous session.\n") finally: if assistant_id: await client.beta.assistants.delete(assistant_id) async def main() -> None: - print("=== OpenAI Assistants Provider Thread Management Examples ===\n") + print("=== OpenAI Assistants Provider Session Management Examples ===\n") - await example_with_automatic_thread_creation() - await example_with_thread_persistence() - await example_with_existing_thread_id() + await example_with_automatic_session_creation() + await example_with_session_persistence() + await example_with_existing_session_id() if __name__ == "__main__": diff --git a/python/samples/02-agents/providers/openai/openai_chat_client_with_thread.py b/python/samples/02-agents/providers/openai/openai_chat_client_with_thread.py index 9486baa5f3..ea225d80f6 100644 --- a/python/samples/02-agents/providers/openai/openai_chat_client_with_thread.py +++ b/python/samples/02-agents/providers/openai/openai_chat_client_with_thread.py @@ -4,15 +4,15 @@ from random import randint from typing import Annotated -from agent_framework import Agent, AgentThread, ChatMessageStore, tool +from agent_framework import Agent, AgentSession, tool from agent_framework.openai import OpenAIChatClient from pydantic import Field """ -OpenAI Chat Client with Thread Management Example +OpenAI Chat Client with Session Management Example -This sample demonstrates thread management with OpenAI Chat Client, showing -conversation threads and message history preservation across interactions. +This sample demonstrates session management with OpenAI Chat Client, showing +conversation sessions and message history preservation across interactions. 
""" @@ -28,9 +28,9 @@ def get_weather( return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." -async def example_with_automatic_thread_creation() -> None: - """Example showing automatic thread creation (service-managed thread).""" - print("=== Automatic Thread Creation Example ===") +async def example_with_automatic_session_creation() -> None: + """Example showing automatic session creation (service-managed session).""" + print("=== Automatic Session Creation Example ===") agent = Agent( client=OpenAIChatClient(), @@ -38,24 +38,24 @@ async def example_with_automatic_thread_creation() -> None: tools=get_weather, ) - # First conversation - no thread provided, will be created automatically + # First conversation - no session provided, will be created automatically query1 = "What's the weather like in Seattle?" print(f"User: {query1}") result1 = await agent.run(query1) print(f"Agent: {result1.text}") - # Second conversation - still no thread provided, will create another new thread + # Second conversation - still no session provided, will create another new session query2 = "What was the last city I asked about?" 
print(f"\nUser: {query2}") result2 = await agent.run(query2) print(f"Agent: {result2.text}") - print("Note: Each call creates a separate thread, so the agent doesn't remember previous context.\n") + print("Note: Each call creates a separate session, so the agent doesn't remember previous context.\n") -async def example_with_thread_persistence() -> None: - """Example showing thread persistence across multiple conversations.""" - print("=== Thread Persistence Example ===") - print("Using the same thread across multiple conversations to maintain context.\n") +async def example_with_session_persistence() -> None: + """Example showing session persistence across multiple conversations.""" + print("=== Session Persistence Example ===") + print("Using the same session across multiple conversations to maintain context.\n") agent = Agent( client=OpenAIChatClient(), @@ -63,32 +63,32 @@ async def example_with_thread_persistence() -> None: tools=get_weather, ) - # Create a new thread that will be reused - thread = agent.get_new_thread() + # Create a new session that will be reused + session = agent.create_session() # First conversation query1 = "What's the weather like in Tokyo?" print(f"User: {query1}") - result1 = await agent.run(query1, thread=thread) + result1 = await agent.run(query1, session=session) print(f"Agent: {result1.text}") - # Second conversation using the same thread - maintains context + # Second conversation using the same session - maintains context query2 = "How about London?" print(f"\nUser: {query2}") - result2 = await agent.run(query2, thread=thread) + result2 = await agent.run(query2, session=session) print(f"Agent: {result2.text}") # Third conversation - agent should remember both previous cities query3 = "Which of the cities I asked about has better weather?" 
print(f"\nUser: {query3}") - result3 = await agent.run(query3, thread=thread) + result3 = await agent.run(query3, session=session) print(f"Agent: {result3.text}") - print("Note: The agent remembers context from previous messages in the same thread.\n") + print("Note: The agent remembers context from previous messages in the same session.\n") -async def example_with_existing_thread_messages() -> None: - """Example showing how to work with existing thread messages for OpenAI.""" - print("=== Existing Thread Messages Example ===") +async def example_with_existing_session_messages() -> None: + """Example showing how to work with existing session messages for OpenAI.""" + print("=== Existing Session Messages Example ===") agent = Agent( client=OpenAIChatClient(), @@ -97,54 +97,54 @@ async def example_with_existing_thread_messages() -> None: ) # Start a conversation and build up message history - thread = agent.get_new_thread() + session = agent.create_session() query1 = "What's the weather in Paris?" 
print(f"User: {query1}") - result1 = await agent.run(query1, thread=thread) + result1 = await agent.run(query1, session=session) print(f"Agent: {result1.text}") - # The thread now contains the conversation history in memory - if thread.message_store: - messages = await thread.message_store.list_messages() - print(f"Thread contains {len(messages or [])} messages") + # The session now contains the conversation history in memory + if session.message_store: + messages = await session.message_store.list_messages() + print(f"Session contains {len(messages or [])} messages") - print("\n--- Continuing with the same thread in a new agent instance ---") + print("\n--- Continuing with the same session in a new agent instance ---") - # Create a new agent instance but use the existing thread with its message history + # Create a new agent instance but use the existing session with its message history new_agent = Agent( client=OpenAIChatClient(), instructions="You are a helpful weather agent.", tools=get_weather, ) - # Use the same thread object which contains the conversation history + # Use the same session object which contains the conversation history query2 = "What was the last city I asked about?" 
print(f"User: {query2}")
-    result2 = await new_agent.run(query2, thread=thread)
+    result2 = await new_agent.run(query2, session=session)
     print(f"Agent: {result2.text}")
     print("Note: The agent continues the conversation using the local message history.\n")
-    print("\n--- Alternative: Creating a new thread from existing messages ---")
+    print("\n--- Alternative: Creating a new session from existing messages ---")
-    # You can also create a new thread from existing messages
-    messages = await thread.message_store.list_messages() if thread.message_store else []
+    # You can also create a new session from existing messages (NOTE(review): "messages" below is never passed to AgentSession(), so the new session starts empty and the closing "same conversation history" note no longer holds — confirm the intended API for seeding history; other samples in this patch read history via session.state["memory"] instead of session.message_store)
+    messages = await session.message_store.list_messages() if session.message_store else []
-    new_thread = AgentThread(message_store=ChatMessageStore(messages))
+    new_session = AgentSession()
     query3 = "How does the Paris weather compare to London?"
     print(f"User: {query3}")
-    result3 = await new_agent.run(query3, thread=new_thread)
+    result3 = await new_agent.run(query3, session=new_session)
     print(f"Agent: {result3.text}")
-    print("Note: This creates a new thread with the same conversation history.\n")
+    print("Note: This creates a new session with the same conversation history.\n")
 async def main() -> None:
-    print("=== OpenAI Chat Client Agent Thread Management Examples ===\n")
+    print("=== OpenAI Chat Client Agent Session Management Examples ===\n")
-    await example_with_automatic_thread_creation()
-    await example_with_thread_persistence()
-    await example_with_existing_thread_messages()
+    await example_with_automatic_session_creation()
+    await example_with_session_persistence()
+    await example_with_existing_session_messages()
 if __name__ == "__main__":
diff --git a/python/samples/02-agents/providers/openai/openai_responses_client_with_hosted_mcp.py b/python/samples/02-agents/providers/openai/openai_responses_client_with_hosted_mcp.py
index f934cd0820..6c27ea36a9 100644
---
a/python/samples/02-agents/providers/openai/openai_responses_client_with_hosted_mcp.py +++ b/python/samples/02-agents/providers/openai/openai_responses_client_with_hosted_mcp.py @@ -14,11 +14,11 @@ """ if TYPE_CHECKING: - from agent_framework import AgentThread, SupportsAgentRun + from agent_framework import AgentSession, SupportsAgentRun -async def handle_approvals_without_thread(query: str, agent: "SupportsAgentRun"): - """When we don't have a thread, we need to ensure we return with the input, approval request and approval.""" +async def handle_approvals_without_session(query: str, agent: "SupportsAgentRun"): + """When we don't have a session, we need to ensure we return with the input, approval request and approval.""" from agent_framework import Message result = await agent.run(query) @@ -42,11 +42,11 @@ async def handle_approvals_without_thread(query: str, agent: "SupportsAgentRun") return result -async def handle_approvals_with_thread(query: str, agent: "SupportsAgentRun", thread: "AgentThread"): - """Here we let the thread deal with the previous responses, and we just rerun with the approval.""" +async def handle_approvals_with_session(query: str, agent: "SupportsAgentRun", session: "AgentSession"): + """Here we let the session deal with the previous responses, and we just rerun with the approval.""" from agent_framework import Message - result = await agent.run(query, thread=thread, store=True) + result = await agent.run(query, session=session, store=True) while len(result.user_input_requests) > 0: new_input: list[Any] = [] for user_input_needed in result.user_input_requests: @@ -61,12 +61,12 @@ async def handle_approvals_with_thread(query: str, agent: "SupportsAgentRun", th contents=[user_input_needed.to_function_approval_response(user_approval.lower() == "y")], ) ) - result = await agent.run(new_input, thread=thread, store=True) + result = await agent.run(new_input, session=session, store=True) return result -async def 
handle_approvals_with_thread_streaming(query: str, agent: "SupportsAgentRun", thread: "AgentThread"): - """Here we let the thread deal with the previous responses, and we just rerun with the approval.""" +async def handle_approvals_with_session_streaming(query: str, agent: "SupportsAgentRun", session: "AgentSession"): + """Here we let the session deal with the previous responses, and we just rerun with the approval.""" from agent_framework import Message new_input: list[Message] = [] @@ -74,7 +74,7 @@ async def handle_approvals_with_thread_streaming(query: str, agent: "SupportsAge while new_input_added: new_input_added = False new_input.append(Message(role="user", text=query)) - async for update in agent.run(new_input, thread=thread, stream=True, options={"store": True}): + async for update in agent.run(new_input, session=session, stream=True, options={"store": True}): if update.user_input_requests: for user_input_needed in update.user_input_requests: print( @@ -93,9 +93,9 @@ async def handle_approvals_with_thread_streaming(query: str, agent: "SupportsAge yield update -async def run_hosted_mcp_without_thread_and_specific_approval() -> None: - """Example showing Mcp Tools with approvals without using a thread.""" - print("=== Mcp with approvals and without thread ===") +async def run_hosted_mcp_without_session_and_specific_approval() -> None: + """Example showing Mcp Tools with approvals without using a session.""" + print("=== Mcp with approvals and without session ===") client = OpenAIResponsesClient() # Create MCP tool with specific approval mode @@ -116,13 +116,13 @@ async def run_hosted_mcp_without_thread_and_specific_approval() -> None: # First query query1 = "How to create an Azure storage account using az cli?" 
print(f"User: {query1}") - result1 = await handle_approvals_without_thread(query1, agent) + result1 = await handle_approvals_without_session(query1, agent) print(f"{agent.name}: {result1}\n") print("\n=======================================\n") # Second query query2 = "What is Microsoft Agent Framework?" print(f"User: {query2}") - result2 = await handle_approvals_without_thread(query2, agent) + result2 = await handle_approvals_without_session(query2, agent) print(f"{agent.name}: {result2}\n") @@ -148,19 +148,19 @@ async def run_hosted_mcp_without_approval() -> None: # First query query1 = "How to create an Azure storage account using az cli?" print(f"User: {query1}") - result1 = await handle_approvals_without_thread(query1, agent) + result1 = await handle_approvals_without_session(query1, agent) print(f"{agent.name}: {result1}\n") print("\n=======================================\n") # Second query query2 = "What is Microsoft Agent Framework?" print(f"User: {query2}") - result2 = await handle_approvals_without_thread(query2, agent) + result2 = await handle_approvals_without_session(query2, agent) print(f"{agent.name}: {result2}\n") -async def run_hosted_mcp_with_thread() -> None: - """Example showing Mcp Tools with approvals using a thread.""" - print("=== Mcp with approvals and with thread ===") +async def run_hosted_mcp_with_session() -> None: + """Example showing Mcp Tools with approvals using a session.""" + print("=== Mcp with approvals and with session ===") client = OpenAIResponsesClient() # Create MCP tool that always requires approval @@ -178,22 +178,22 @@ async def run_hosted_mcp_with_thread() -> None: tools=mcp_tool, ) as agent: # First query - thread = agent.get_new_thread() + session = agent.create_session() query1 = "How to create an Azure storage account using az cli?" 
print(f"User: {query1}") - result1 = await handle_approvals_with_thread(query1, agent, thread) + result1 = await handle_approvals_with_session(query1, agent, session) print(f"{agent.name}: {result1}\n") print("\n=======================================\n") # Second query query2 = "What is Microsoft Agent Framework?" print(f"User: {query2}") - result2 = await handle_approvals_with_thread(query2, agent, thread) + result2 = await handle_approvals_with_session(query2, agent, session) print(f"{agent.name}: {result2}\n") -async def run_hosted_mcp_with_thread_streaming() -> None: - """Example showing Mcp Tools with approvals using a thread.""" - print("=== Mcp with approvals and with thread ===") +async def run_hosted_mcp_with_session_streaming() -> None: + """Example showing Mcp Tools with approvals using a session.""" + print("=== Mcp with approvals and with session ===") client = OpenAIResponsesClient() # Create MCP tool that always requires approval @@ -211,11 +211,11 @@ async def run_hosted_mcp_with_thread_streaming() -> None: tools=mcp_tool, ) as agent: # First query - thread = agent.get_new_thread() + session = agent.create_session() query1 = "How to create an Azure storage account using az cli?" print(f"User: {query1}") print(f"{agent.name}: ", end="") - async for update in handle_approvals_with_thread_streaming(query1, agent, thread): + async for update in handle_approvals_with_session_streaming(query1, agent, session): print(update, end="") print("\n") print("\n=======================================\n") @@ -223,7 +223,7 @@ async def run_hosted_mcp_with_thread_streaming() -> None: query2 = "What is Microsoft Agent Framework?" 
print(f"User: {query2}") print(f"{agent.name}: ", end="") - async for update in handle_approvals_with_thread_streaming(query2, agent, thread): + async for update in handle_approvals_with_session_streaming(query2, agent, session): print(update, end="") print("\n") @@ -232,9 +232,9 @@ async def main() -> None: print("=== OpenAI Responses Client Agent with Hosted Mcp Tools Examples ===\n") await run_hosted_mcp_without_approval() - await run_hosted_mcp_without_thread_and_specific_approval() - await run_hosted_mcp_with_thread() - await run_hosted_mcp_with_thread_streaming() + await run_hosted_mcp_without_session_and_specific_approval() + await run_hosted_mcp_with_session() + await run_hosted_mcp_with_session_streaming() if __name__ == "__main__": diff --git a/python/samples/02-agents/providers/openai/openai_responses_client_with_thread.py b/python/samples/02-agents/providers/openai/openai_responses_client_with_thread.py index f3e6024e17..4000db96cc 100644 --- a/python/samples/02-agents/providers/openai/openai_responses_client_with_thread.py +++ b/python/samples/02-agents/providers/openai/openai_responses_client_with_thread.py @@ -4,14 +4,14 @@ from random import randint from typing import Annotated -from agent_framework import Agent, AgentThread, tool +from agent_framework import Agent, AgentSession, tool from agent_framework.openai import OpenAIResponsesClient from pydantic import Field """ -OpenAI Responses Client with Thread Management Example +OpenAI Responses Client with Session Management Example -This sample demonstrates thread management with OpenAI Responses Client, showing +This sample demonstrates session management with OpenAI Responses Client, showing persistent conversation context and simplified response handling. """ @@ -28,9 +28,9 @@ def get_weather( return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." 
-async def example_with_automatic_thread_creation() -> None: - """Example showing automatic thread creation.""" - print("=== Automatic Thread Creation Example ===") +async def example_with_automatic_session_creation() -> None: + """Example showing automatic session creation.""" + print("=== Automatic Session Creation Example ===") agent = Agent( client=OpenAIResponsesClient(), @@ -38,26 +38,26 @@ async def example_with_automatic_thread_creation() -> None: tools=get_weather, ) - # First conversation - no thread provided, will be created automatically + # First conversation - no session provided, will be created automatically query1 = "What's the weather like in Seattle?" print(f"User: {query1}") result1 = await agent.run(query1) print(f"Agent: {result1.text}") - # Second conversation - still no thread provided, will create another new thread + # Second conversation - still no session provided, will create another new session query2 = "What was the last city I asked about?" print(f"\nUser: {query2}") result2 = await agent.run(query2) print(f"Agent: {result2.text}") - print("Note: Each call creates a separate thread, so the agent doesn't remember previous context.\n") + print("Note: Each call creates a separate session, so the agent doesn't remember previous context.\n") -async def example_with_thread_persistence_in_memory() -> None: +async def example_with_session_persistence_in_memory() -> None: """ - Example showing thread persistence across multiple conversations. + Example showing session persistence across multiple conversations. In this example, messages are stored in-memory. 
""" - print("=== Thread Persistence Example (In-Memory) ===") + print("=== Session Persistence Example (In-Memory) ===") agent = Agent( client=OpenAIResponsesClient(), @@ -65,38 +65,38 @@ async def example_with_thread_persistence_in_memory() -> None: tools=get_weather, ) - # Create a new thread that will be reused - thread = agent.get_new_thread() + # Create a new session that will be reused + session = agent.create_session() # First conversation query1 = "What's the weather like in Tokyo?" print(f"User: {query1}") - result1 = await agent.run(query1, thread=thread, store=False) + result1 = await agent.run(query1, session=session, store=False) print(f"Agent: {result1.text}") - # Second conversation using the same thread - maintains context + # Second conversation using the same session - maintains context query2 = "How about London?" print(f"\nUser: {query2}") - result2 = await agent.run(query2, thread=thread, store=False) + result2 = await agent.run(query2, session=session, store=False) print(f"Agent: {result2.text}") # Third conversation - agent should remember both previous cities query3 = "Which of the cities I asked about has better weather?" print(f"\nUser: {query3}") - result3 = await agent.run(query3, thread=thread, store=False) + result3 = await agent.run(query3, session=session, store=False) print(f"Agent: {result3.text}") - print("Note: The agent remembers context from previous messages in the same thread.\n") + print("Note: The agent remembers context from previous messages in the same session.\n") -async def example_with_existing_thread_id() -> None: +async def example_with_existing_session_id() -> None: """ - Example showing how to work with an existing thread ID from the service. + Example showing how to work with an existing session ID from the service. In this example, messages are stored on the server using OpenAI conversation state. 
""" - print("=== Existing Thread ID Example ===") + print("=== Existing Session ID Example ===") - # First, create a conversation and capture the thread ID - existing_thread_id = None + # First, create a conversation and capture the session ID + existing_session_id = None agent = Agent( client=OpenAIResponsesClient(), @@ -104,20 +104,20 @@ async def example_with_existing_thread_id() -> None: tools=get_weather, ) - # Start a conversation and get the thread ID - thread = agent.get_new_thread() + # Start a conversation and get the session ID + session = agent.create_session() query1 = "What's the weather in Paris?" print(f"User: {query1}") - result1 = await agent.run(query1, thread=thread) + result1 = await agent.run(query1, session=session) print(f"Agent: {result1.text}") - # The thread ID is set after the first response - existing_thread_id = thread.service_thread_id - print(f"Thread ID: {existing_thread_id}") + # The session ID is set after the first response + existing_session_id = session.service_session_id + print(f"Session ID: {existing_session_id}") - if existing_thread_id: - print("\n--- Continuing with the same thread ID in a new agent instance ---") + if existing_session_id: + print("\n--- Continuing with the same session ID in a new agent instance ---") agent = Agent( client=OpenAIResponsesClient(), @@ -125,22 +125,22 @@ async def example_with_existing_thread_id() -> None: tools=get_weather, ) - # Create a thread with the existing ID - thread = AgentThread(service_thread_id=existing_thread_id) + # Create a session with the existing ID + session = AgentSession(service_session_id=existing_session_id) query2 = "What was the last city I asked about?" 
print(f"User: {query2}") - result2 = await agent.run(query2, thread=thread) + result2 = await agent.run(query2, session=session) print(f"Agent: {result2.text}") - print("Note: The agent continues the conversation from the previous thread by using thread ID.\n") + print("Note: The agent continues the conversation from the previous session by using session ID.\n") async def main() -> None: - print("=== OpenAI Response Client Agent Thread Management Examples ===\n") + print("=== OpenAI Response Client Agent Session Management Examples ===\n") - await example_with_automatic_thread_creation() - await example_with_thread_persistence_in_memory() - await example_with_existing_thread_id() + await example_with_automatic_session_creation() + await example_with_session_persistence_in_memory() + await example_with_existing_session_id() if __name__ == "__main__": diff --git a/python/samples/02-agents/tools/function_tool_recover_from_failures.py b/python/samples/02-agents/tools/function_tool_recover_from_failures.py index f3ce9f1adc..9c506d1304 100644 --- a/python/samples/02-agents/tools/function_tool_recover_from_failures.py +++ b/python/samples/02-agents/tools/function_tool_recover_from_failures.py @@ -44,29 +44,28 @@ async def main(): instructions="Use the provided tools.", tools=[greet, safe_divide], ) - thread = agent.get_new_thread() + session = agent.create_session() print("=" * 60) print("Step 1: Call divide(10, 0) - tool raises exception") - response = await agent.run("Divide 10 by 0", thread=thread) + response = await agent.run("Divide 10 by 0", session=session) print(f"Response: {response.text}") print("=" * 60) print("Step 2: Call greet('Bob') - conversation can keep going.") - response = await agent.run("Greet Bob", thread=thread) + response = await agent.run("Greet Bob", session=session) print(f"Response: {response.text}") print("=" * 60) - print("Replay the conversation:") - assert thread.message_store - assert thread.message_store.list_messages - for idx, msg in 
enumerate(await thread.message_store.list_messages()): - if msg.text: - print(f"{idx + 1} {msg.author_name or msg.role}: {msg.text} ") - for content in msg.contents: - if content.type == "function_call": - print( - f"{idx + 1} {msg.author_name}: calling function: {content.name} with arguments: {content.arguments}" - ) - if content.type == "function_result": - print(f"{idx + 1} {msg.role}: {content.result if content.result else content.exception}") + # TODO: Use history providers to replay the conversation + # print("Replay the conversation:") + # for idx, msg in enumerate(messages): + # if msg.text: + # print(f"{idx + 1} {msg.author_name or msg.role}: {msg.text} ") + # for content in msg.contents: + # if content.type == "function_call": + # print( + # f"{idx + 1} {msg.author_name}: calling function: {content.name} with arguments: {content.arguments}" + # ) + # if content.type == "function_result": + # print(f"{idx + 1} {msg.role}: {content.result if content.result else content.exception}") """ diff --git a/python/samples/02-agents/tools/function_tool_with_approval_and_threads.py b/python/samples/02-agents/tools/function_tool_with_approval_and_threads.py index e3f442ecee..004a182876 100644 --- a/python/samples/02-agents/tools/function_tool_with_approval_and_threads.py +++ b/python/samples/02-agents/tools/function_tool_with_approval_and_threads.py @@ -7,11 +7,11 @@ from agent_framework.azure import AzureOpenAIChatClient """ -Tool Approvals with Threads +Tool Approvals with Sessions -This sample demonstrates using tool approvals with threads. -With threads, you don't need to manually pass previous messages - -the thread stores and retrieves them automatically. +This sample demonstrates using tool approvals with sessions. +With sessions, you don't need to manually pass previous messages - +the session stores and retrieves them automatically. 
""" @@ -25,8 +25,8 @@ def add_to_calendar( async def approval_example() -> None: - """Example showing approval with threads.""" - print("=== Tool Approval with Thread ===\n") + """Example showing approval with sessions.""" + print("=== Tool Approval with Session ===\n") agent = Agent( client=AzureOpenAIChatClient(), @@ -35,12 +35,12 @@ async def approval_example() -> None: tools=[add_to_calendar], ) - thread = agent.get_new_thread() + session = agent.create_session() # Step 1: Agent requests to call the tool query = "Add a dentist appointment on March 15th" print(f"User: {query}") - result = await agent.run(query, thread=thread) + result = await agent.run(query, session=session) # Check for approval requests if result.user_input_requests: @@ -55,14 +55,14 @@ async def approval_example() -> None: # Step 2: Send approval response approval_response = request.to_function_approval_response(approved=approved) - result = await agent.run(Message("user", [approval_response]), thread=thread) + result = await agent.run(Message("user", [approval_response]), session=session) print(f"Agent: {result}\n") async def rejection_example() -> None: - """Example showing rejection with threads.""" - print("=== Tool Rejection with Thread ===\n") + """Example showing rejection with sessions.""" + print("=== Tool Rejection with Session ===\n") agent = Agent( client=AzureOpenAIChatClient(), @@ -71,11 +71,11 @@ async def rejection_example() -> None: tools=[add_to_calendar], ) - thread = agent.get_new_thread() + session = agent.create_session() query = "Add a team meeting on December 20th" print(f"User: {query}") - result = await agent.run(query, thread=thread) + result = await agent.run(query, session=session) if result.user_input_requests: for request in result.user_input_requests: @@ -88,7 +88,7 @@ async def rejection_example() -> None: # Send rejection response rejection_response = request.to_function_approval_response(approved=False) - result = await agent.run(Message("user", 
[rejection_response]), thread=thread) + result = await agent.run(Message("user", [rejection_response]), session=session) print(f"Agent: {result}\n") diff --git a/python/samples/02-agents/tools/function_tool_with_max_exceptions.py b/python/samples/02-agents/tools/function_tool_with_max_exceptions.py index 7e60487704..89d883174c 100644 --- a/python/samples/02-agents/tools/function_tool_with_max_exceptions.py +++ b/python/samples/02-agents/tools/function_tool_with_max_exceptions.py @@ -36,31 +36,30 @@ async def main(): instructions="Use the provided tools.", tools=[safe_divide], ) - thread = agent.get_new_thread() + session = agent.create_session() print("=" * 60) print("Step 1: Call divide(10, 0) - tool raises exception") - response = await agent.run("Divide 10 by 0", thread=thread) + response = await agent.run("Divide 10 by 0", session=session) print(f"Response: {response.text}") print("=" * 60) print("Step 2: Call divide(100, 0) - will refuse to execute due to max_invocation_exceptions") - response = await agent.run("Divide 100 by 0", thread=thread) + response = await agent.run("Divide 100 by 0", session=session) print(f"Response: {response.text}") print("=" * 60) print(f"Number of tool calls attempted: {safe_divide.invocation_count}") print(f"Number of tool calls failed: {safe_divide.invocation_exception_count}") - print("Replay the conversation:") - assert thread.message_store - assert thread.message_store.list_messages - for idx, msg in enumerate(await thread.message_store.list_messages()): - if msg.text: - print(f"{idx + 1} {msg.author_name or msg.role}: {msg.text} ") - for content in msg.contents: - if content.type == "function_call": - print( - f"{idx + 1} {msg.author_name}: calling function: {content.name} with arguments: {content.arguments}" - ) - if content.type == "function_result": - print(f"{idx + 1} {msg.role}: {content.result if content.result else content.exception}") + # TODO: Use history providers to replay the conversation + # print("Replay the 
conversation:") + # for idx, msg in enumerate(messages): + # if msg.text: + # print(f"{idx + 1} {msg.author_name or msg.role}: {msg.text} ") + # for content in msg.contents: + # if content.type == "function_call": + # print( + # f"{idx + 1} {msg.author_name}: calling function: {content.name} with arguments: {content.arguments}" + # ) + # if content.type == "function_result": + # print(f"{idx + 1} {msg.role}: {content.result if content.result else content.exception}") """ diff --git a/python/samples/02-agents/tools/function_tool_with_max_invocations.py b/python/samples/02-agents/tools/function_tool_with_max_invocations.py index be9d37d807..c8bdc306c3 100644 --- a/python/samples/02-agents/tools/function_tool_with_max_invocations.py +++ b/python/samples/02-agents/tools/function_tool_with_max_invocations.py @@ -25,31 +25,30 @@ async def main(): instructions="Use the provided tools.", tools=[unicorn_function], ) - thread = agent.get_new_thread() + session = agent.create_session() print("=" * 60) print("Step 1: Call unicorn_function") - response = await agent.run("Call 5 unicorns!", thread=thread) + response = await agent.run("Call 5 unicorns!", session=session) print(f"Response: {response.text}") print("=" * 60) print("Step 2: Call unicorn_function again - will refuse to execute due to max_invocations") - response = await agent.run("Call 10 unicorns and use the function to do it.", thread=thread) + response = await agent.run("Call 10 unicorns and use the function to do it.", session=session) print(f"Response: {response.text}") print("=" * 60) print(f"Number of tool calls attempted: {unicorn_function.invocation_count}") print(f"Number of tool calls failed: {unicorn_function.invocation_exception_count}") - print("Replay the conversation:") - assert thread.message_store - assert thread.message_store.list_messages - for idx, msg in enumerate(await thread.message_store.list_messages()): - if msg.text: - print(f"{idx + 1} {msg.author_name or msg.role}: {msg.text} ") - for 
content in msg.contents: - if content.type == "function_call": - print( - f"{idx + 1} {msg.author_name}: calling function: {content.name} with arguments: {content.arguments}" - ) - if content.type == "function_result": - print(f"{idx + 1} {msg.role}: {content.result if content.result else content.exception}") + # TODO: Use history providers to replay the conversation + # print("Replay the conversation:") + # for idx, msg in enumerate(messages): + # if msg.text: + # print(f"{idx + 1} {msg.author_name or msg.role}: {msg.text} ") + # for content in msg.contents: + # if content.type == "function_call": + # print( + # f"{idx + 1} {msg.author_name}: calling function: {content.name} with arguments: {content.arguments}" + # ) + # if content.type == "function_result": + # print(f"{idx + 1} {msg.role}: {content.result if content.result else content.exception}") """ diff --git a/python/samples/02-agents/tools/function_tool_with_thread_injection.py b/python/samples/02-agents/tools/function_tool_with_thread_injection.py index b73212774b..afd2bac555 100644 --- a/python/samples/02-agents/tools/function_tool_with_thread_injection.py +++ b/python/samples/02-agents/tools/function_tool_with_thread_injection.py @@ -3,15 +3,15 @@ import asyncio from typing import Annotated, Any -from agent_framework import AgentThread, tool +from agent_framework import AgentSession, tool from agent_framework.openai import OpenAIChatClient from pydantic import Field """ -AI Function with Thread Injection Example +AI Function with Session Injection Example -This example demonstrates the behavior when passing 'thread' to agent.run() -and accessing that thread in AI function. +This example demonstrates the behavior when passing 'session' to agent.run() +and accessing that session in AI function. 
""" @@ -23,14 +23,11 @@ async def get_weather( **kwargs: Any, ) -> str: """Get the weather for a given location.""" - # Get thread object from kwargs - thread = kwargs.get("thread") - if thread and isinstance(thread, AgentThread): - if thread.message_store: - messages = await thread.message_store.list_messages() - print(f"Thread contains {len(messages)} messages.") - elif thread.service_thread_id: - print(f"Thread ID: {thread.service_thread_id}.") + # Get session object from kwargs + session = kwargs.get("session") + if session and isinstance(session, AgentSession): + if session.service_session_id: + print(f"Session ID: {session.service_session_id}.") return f"The weather in {location} is cloudy." @@ -40,13 +37,13 @@ async def main() -> None: name="WeatherAgent", instructions="You are a helpful weather assistant.", tools=[get_weather] ) - # Create a thread - thread = agent.get_new_thread() + # Create a session + session = agent.create_session() - # Run the agent with the thread - print(f"Agent: {await agent.run('What is the weather in London?', thread=thread)}") - print(f"Agent: {await agent.run('What is the weather in Amsterdam?', thread=thread)}") - print(f"Agent: {await agent.run('What cities did I ask about?', thread=thread)}") + # Run the agent with the session + print(f"Agent: {await agent.run('What is the weather in London?', session=session)}") + print(f"Agent: {await agent.run('What is the weather in Amsterdam?', session=session)}") + print(f"Agent: {await agent.run('What cities did I ask about?', session=session)}") if __name__ == "__main__": diff --git a/python/samples/03-workflows/agents/azure_ai_agents_with_shared_thread.py b/python/samples/03-workflows/agents/azure_ai_agents_with_shared_thread.py index 3492be6474..988d3f539f 100644 --- a/python/samples/03-workflows/agents/azure_ai_agents_with_shared_thread.py +++ b/python/samples/03-workflows/agents/azure_ai_agents_with_shared_thread.py @@ -7,7 +7,7 @@ AgentExecutor, AgentExecutorRequest, 
AgentExecutorResponse, - ChatMessageStore, + InMemoryHistoryProvider, WorkflowBuilder, WorkflowContext, WorkflowRunState, @@ -59,22 +59,23 @@ async def main() -> None: credential=AzureCliCredential(), ) + # set the same context provider, with the same source_id, for both agents to share the thread writer = client.as_agent( instructions=("You are a concise copywriter. Provide a single, punchy marketing sentence based on the prompt."), name="writer", + context_providers=[InMemoryHistoryProvider("memory")], ) reviewer = client.as_agent( instructions=("You are a thoughtful reviewer. Give brief feedback on the previous assistant message."), name="reviewer", + context_providers=[InMemoryHistoryProvider("memory")], ) - shared_thread = writer.get_new_thread() - # Set the message store to store messages in memory. - shared_thread.message_store = ChatMessageStore() - - writer_executor = AgentExecutor(writer, agent_thread=shared_thread) - reviewer_executor = AgentExecutor(reviewer, agent_thread=shared_thread) + # Create the shared session + shared_session = writer.create_session() + writer_executor = AgentExecutor(writer, session=shared_session) + reviewer_executor = AgentExecutor(reviewer, session=shared_session) workflow = ( WorkflowBuilder(start_executor=writer_executor) @@ -88,13 +89,15 @@ async def main() -> None: # Setting store=False to avoid storing messages in the service for this example. options={"store": False}, ) + # The final state should be IDLE since the workflow no longer has messages to # process after the reviewer agent responds. assert result.get_final_state() == WorkflowRunState.IDLE - # The shared thread now contains the conversation between the writer and reviewer. Print it out. - print("=== Shared Thread Conversation ===") - for message in shared_thread.message_store.messages: + # The shared session now contains the conversation between the writer and reviewer. Print it out. 
+ print("=== Shared Session Conversation ===") + memory_state = shared_session.state.get("memory", {}) + for message in memory_state.get("messages", []): print(f"{message.author_name or message.role}: {message.text}") diff --git a/python/samples/03-workflows/agents/workflow_as_agent_with_thread.py b/python/samples/03-workflows/agents/workflow_as_agent_with_thread.py index 7e56f6618a..6a8716ce4c 100644 --- a/python/samples/03-workflows/agents/workflow_as_agent_with_thread.py +++ b/python/samples/03-workflows/agents/workflow_as_agent_with_thread.py @@ -3,17 +3,17 @@ import asyncio import os -from agent_framework import AgentThread, ChatMessageStore +from agent_framework import AgentSession from agent_framework.azure import AzureOpenAIResponsesClient from agent_framework.orchestrations import SequentialBuilder from azure.identity import AzureCliCredential """ -Sample: Workflow as Agent with Thread Conversation History and Checkpointing +Sample: Workflow as Agent with Session Conversation History and Checkpointing -This sample demonstrates how to use AgentThread with a workflow wrapped as an agent +This sample demonstrates how to use AgentSession with a workflow wrapped as an agent to maintain conversation history across multiple invocations. When using as_agent(), -the thread's message store history is included in each workflow run, enabling +the session's history is included in each workflow run, enabling the workflow participants to reference prior conversation context. 
It also demonstrates how to enable checkpointing for workflow execution state @@ -21,8 +21,8 @@ Key concepts: - Workflows can be wrapped as agents using workflow.as_agent() -- AgentThread with ChatMessageStore preserves conversation history -- Each call to agent.run() includes thread history + new message +- AgentSession preserves conversation history +- Each call to agent.run() includes session history + new message - Participants in the workflow see the full conversation context - checkpoint_storage parameter enables workflow state persistence @@ -68,19 +68,18 @@ async def main() -> None: # Wrap the workflow as an agent agent = workflow.as_agent(name="ConversationalWorkflowAgent") - # Create a thread with a ChatMessageStore to maintain history - message_store = ChatMessageStore() - thread = AgentThread(message_store=message_store) + # Create a session to maintain history + session = agent.create_session() print("=" * 60) - print("Workflow as Agent with Thread - Multi-turn Conversation") + print("Workflow as Agent with Session - Multi-turn Conversation") print("=" * 60) # First turn: Introduce a topic query1 = "My name is Alex and I'm learning about machine learning." print(f"\n[Turn 1] User: {query1}") - response1 = await agent.run(query1, thread=thread) + response1 = await agent.run(query1, session=session) if response1.messages: for msg in response1.messages: speaker = msg.author_name or msg.role @@ -90,7 +89,7 @@ async def main() -> None: query2 = "What was my name again, and what am I learning about?" print(f"\n[Turn 2] User: {query2}") - response2 = await agent.run(query2, thread=thread) + response2 = await agent.run(query2, session=session) if response2.messages: for msg in response2.messages: speaker = msg.author_name or msg.role @@ -100,7 +99,7 @@ async def main() -> None: query3 = "Can you suggest a good first project for me to try?" 
print(f"\n[Turn 3] User: {query3}") - response3 = await agent.run(query3, thread=thread) + response3 = await agent.run(query3, session=session) if response3.messages: for msg in response3.messages: speaker = msg.author_name or msg.role @@ -108,20 +107,20 @@ async def main() -> None: # Show the accumulated conversation history print("\n" + "=" * 60) - print("Full Thread History") + print("Full Session History") print("=" * 60) - if thread.message_store: - history = await thread.message_store.list_messages() - for i, msg in enumerate(history, start=1): - role = msg.role if hasattr(msg.role, "value") else str(msg.role) - speaker = msg.author_name or role - text_preview = msg.text[:80] + "..." if len(msg.text) > 80 else msg.text - print(f"{i:02d}. [{speaker}]: {text_preview}") + memory_state = session.state.get("memory", {}) + history = memory_state.get("messages", []) + for i, msg in enumerate(history, start=1): + role = msg.role if hasattr(msg.role, "value") else str(msg.role) + speaker = msg.author_name or role + text_preview = msg.text[:80] + "..." if len(msg.text) > 80 else msg.text + print(f"{i:02d}. [{speaker}]: {text_preview}") -async def demonstrate_thread_serialization() -> None: +async def demonstrate_session_serialization() -> None: """ - Demonstrates serializing and resuming a thread with a workflow agent. + Demonstrates serializing and resuming a session with a workflow agent. This shows how conversation history can be persisted and restored, enabling long-running conversational workflows. 
@@ -140,36 +139,35 @@ async def demonstrate_thread_serialization() -> None: workflow = SequentialBuilder(participants=[memory_assistant]).build() agent = workflow.as_agent(name="MemoryWorkflowAgent") - # Create initial thread and have a conversation - thread = AgentThread(message_store=ChatMessageStore()) + # Create initial session and have a conversation + session = agent.create_session() print("\n" + "=" * 60) - print("Thread Serialization Demo") + print("Session Serialization Demo") print("=" * 60) # First interaction query = "Remember this: the secret code is ALPHA-7." print(f"\n[Session 1] User: {query}") - response = await agent.run(query, thread=thread) + response = await agent.run(query, session=session) if response.messages: print(f"[assistant]: {response.messages[0].text}") - # Serialize thread state (could be saved to database/file) - serialized_state = await thread.serialize() - print("\n[Serialized thread state for persistence]") + # Serialize session state (could be saved to database/file) + serialized_state = session.to_dict() + print("\n[Serialized session state for persistence]") - # Simulate a new session by creating a new thread from serialized state - restored_thread = AgentThread(message_store=ChatMessageStore()) - await restored_thread.update_from_thread_state(serialized_state) + # Simulate a new session by creating a new session from serialized state + restored_session = AgentSession.from_dict(serialized_state) - # Continue conversation with restored thread + # Continue conversation with restored session query = "What was the secret code I told you?" 
print(f"\n[Session 2 - Restored] User: {query}") - response = await agent.run(query, thread=restored_thread) + response = await agent.run(query, session=restored_session) if response.messages: print(f"[assistant]: {response.messages[0].text}") if __name__ == "__main__": asyncio.run(main()) - asyncio.run(demonstrate_thread_serialization()) + asyncio.run(demonstrate_session_serialization()) diff --git a/python/samples/03-workflows/checkpoint/workflow_as_agent_checkpoint.py b/python/samples/03-workflows/checkpoint/workflow_as_agent_checkpoint.py index 2e0362dd73..4b0b7e88b8 100644 --- a/python/samples/03-workflows/checkpoint/workflow_as_agent_checkpoint.py +++ b/python/samples/03-workflows/checkpoint/workflow_as_agent_checkpoint.py @@ -15,9 +15,9 @@ - How to resume a workflow-as-agent from a checkpoint Key concepts: -- Thread (AgentThread): Maintains conversation history across agent invocations +- Thread (AgentSession): Maintains conversation history across agent invocations - Checkpoint: Persists workflow execution state for pause/resume capability -- These are complementary: threads track conversation, checkpoints track workflow state +- These are complementary: sessions track conversation, checkpoints track workflow state Prerequisites: - AZURE_AI_PROJECT_ENDPOINT must be your Azure AI Foundry Agent Service (V2) project endpoint. 
@@ -28,8 +28,7 @@ import os from agent_framework import ( - AgentThread, - ChatMessageStore, + AgentSession, InMemoryCheckpointStorage, ) from agent_framework.azure import AzureOpenAIResponsesClient @@ -102,21 +101,21 @@ async def checkpointing_with_thread() -> None: workflow = SequentialBuilder(participants=[assistant]).build() agent = workflow.as_agent(name="MemoryAgent") - # Create both thread (for conversation) and checkpoint storage (for workflow state) - thread = AgentThread(message_store=ChatMessageStore()) + # Create both session (for conversation) and checkpoint storage (for workflow state) + session = agent.create_session() checkpoint_storage = InMemoryCheckpointStorage() # First turn query1 = "My favorite color is blue. Remember that." print(f"\n[Turn 1] User: {query1}") - response1 = await agent.run(query1, thread=thread, checkpoint_storage=checkpoint_storage) + response1 = await agent.run(query1, session=session, checkpoint_storage=checkpoint_storage) if response1.messages: print(f"[assistant]: {response1.messages[0].text}") - # Second turn - agent should remember from thread history + # Second turn - agent should remember from session history query2 = "What's my favorite color?" 
print(f"\n[Turn 2] User: {query2}") - response2 = await agent.run(query2, thread=thread, checkpoint_storage=checkpoint_storage) + response2 = await agent.run(query2, session=session, checkpoint_storage=checkpoint_storage) if response2.messages: print(f"[assistant]: {response2.messages[0].text}") @@ -124,9 +123,9 @@ async def checkpointing_with_thread() -> None: checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow.name) print(f"\nTotal checkpoints across both turns: {len(checkpoints)}") - if thread.message_store: - history = await thread.message_store.list_messages() - print(f"Messages in thread history: {len(history)}") + memory_state = session.state.get("memory", {}) + history = memory_state.get("messages", []) + print(f"Messages in session history: {len(history)}") async def streaming_with_checkpoints() -> None: diff --git a/python/samples/04-hosting/azure_functions/04_single_agent_orchestration_chaining/function_app.py b/python/samples/04-hosting/azure_functions/04_single_agent_orchestration_chaining/function_app.py index 33ccc5319f..0b6f97f87a 100644 --- a/python/samples/04-hosting/azure_functions/04_single_agent_orchestration_chaining/function_app.py +++ b/python/samples/04-hosting/azure_functions/04_single_agent_orchestration_chaining/function_app.py @@ -5,7 +5,7 @@ Components used in this sample: - AzureOpenAIChatClient to construct the writer agent hosted by Agent Framework. - AgentFunctionApp to surface HTTP and orchestration triggers via the Azure Functions extension. -- Durable Functions orchestration to run sequential agent invocations on the same conversation thread. +- Durable Functions orchestration to run sequential agent invocations on the same conversation session. 
Prerequisites: configure `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_CHAT_DEPLOYMENT_NAME`, and either `AZURE_OPENAI_API_KEY` or authenticate with Azure CLI before starting the Functions host.""" @@ -45,17 +45,17 @@ def _create_writer_agent() -> Any: app = AgentFunctionApp(agents=[_create_writer_agent()], enable_health_check=True) -# 4. Orchestration that runs the agent sequentially on a shared thread for chaining behaviour. +# 4. Orchestration that runs the agent sequentially on a shared session for chaining behaviour. @app.orchestration_trigger(context_name="context") def single_agent_orchestration(context: DurableOrchestrationContext) -> Generator[Any, Any, str]: - """Run the writer agent twice on the same thread to mirror chaining behaviour.""" + """Run the writer agent twice on the same session to mirror chaining behaviour.""" writer = app.get_agent(context, WRITER_AGENT_NAME) - writer_thread = writer.get_new_thread() + writer_session = writer.create_session() initial = yield writer.run( messages="Write a concise inspirational sentence about learning.", - thread=writer_thread, + session=writer_session, ) improved_prompt = ( @@ -65,7 +65,7 @@ def single_agent_orchestration(context: DurableOrchestrationContext) -> Generato refined = yield writer.run( messages=improved_prompt, - thread=writer_thread, + session=writer_session, ) return refined.text diff --git a/python/samples/04-hosting/azure_functions/05_multi_agent_orchestration_concurrency/function_app.py b/python/samples/04-hosting/azure_functions/05_multi_agent_orchestration_concurrency/function_app.py index 0be448295d..148835033f 100644 --- a/python/samples/04-hosting/azure_functions/05_multi_agent_orchestration_concurrency/function_app.py +++ b/python/samples/04-hosting/azure_functions/05_multi_agent_orchestration_concurrency/function_app.py @@ -64,12 +64,12 @@ def multi_agent_concurrent_orchestration(context: DurableOrchestrationContext) - physicist = app.get_agent(context, PHYSICIST_AGENT_NAME) chemist = 
app.get_agent(context, CHEMIST_AGENT_NAME) - physicist_thread = physicist.get_new_thread() - chemist_thread = chemist.get_new_thread() + physicist_session = physicist.create_session() + chemist_session = chemist.create_session() # Create tasks from agent.run() calls - physicist_task = physicist.run(messages=str(prompt), thread=physicist_thread) - chemist_task = chemist.run(messages=str(prompt), thread=chemist_thread) + physicist_task = physicist.run(messages=str(prompt), session=physicist_session) + chemist_task = chemist.run(messages=str(prompt), session=chemist_session) # Execute both tasks concurrently using task_all task_results = yield context.task_all([physicist_task, chemist_task]) diff --git a/python/samples/04-hosting/azure_functions/06_multi_agent_orchestration_conditionals/function_app.py b/python/samples/04-hosting/azure_functions/06_multi_agent_orchestration_conditionals/function_app.py index 0dbfeefd5c..148c6eaad5 100644 --- a/python/samples/04-hosting/azure_functions/06_multi_agent_orchestration_conditionals/function_app.py +++ b/python/samples/04-hosting/azure_functions/06_multi_agent_orchestration_conditionals/function_app.py @@ -89,7 +89,7 @@ def spam_detection_orchestration(context: DurableOrchestrationContext) -> Genera spam_agent = app.get_agent(context, SPAM_AGENT_NAME) email_agent = app.get_agent(context, EMAIL_AGENT_NAME) - spam_thread = spam_agent.get_new_thread() + spam_session = spam_agent.create_session() spam_prompt = ( "Analyze this email for spam content and return a JSON response with 'is_spam' (boolean) " @@ -100,7 +100,7 @@ def spam_detection_orchestration(context: DurableOrchestrationContext) -> Genera spam_result_raw = yield spam_agent.run( messages=spam_prompt, - thread=spam_thread, + session=spam_session, options={"response_format": SpamDetectionResult}, ) @@ -113,7 +113,7 @@ def spam_detection_orchestration(context: DurableOrchestrationContext) -> Genera result = yield context.call_activity("handle_spam_email", 
spam_result.reason) # type: ignore[misc] return result - email_thread = email_agent.get_new_thread() + email_session = email_agent.create_session() email_prompt = ( "Draft a professional response to this email. Return a JSON response with a 'response' field " @@ -124,7 +124,7 @@ def spam_detection_orchestration(context: DurableOrchestrationContext) -> Genera email_result_raw = yield email_agent.run( messages=email_prompt, - thread=email_thread, + session=email_session, options={"response_format": EmailResponse}, ) diff --git a/python/samples/04-hosting/azure_functions/07_single_agent_orchestration_hitl/function_app.py b/python/samples/04-hosting/azure_functions/07_single_agent_orchestration_hitl/function_app.py index 931092c6cc..644ed9ed23 100644 --- a/python/samples/04-hosting/azure_functions/07_single_agent_orchestration_hitl/function_app.py +++ b/python/samples/04-hosting/azure_functions/07_single_agent_orchestration_hitl/function_app.py @@ -93,13 +93,13 @@ def content_generation_hitl_orchestration(context: DurableOrchestrationContext) raise ValueError(f"Invalid content generation input: {exc}") from exc writer = app.get_agent(context, WRITER_AGENT_NAME) - writer_thread = writer.get_new_thread() + writer_session = writer.create_session() context.set_custom_status(f"Starting content generation for topic: {payload.topic}") initial_raw = yield writer.run( messages=f"Write a short article about '{payload.topic}'.", - thread=writer_thread, + session=writer_session, options={"response_format": GeneratedContent}, ) @@ -150,7 +150,7 @@ def content_generation_hitl_orchestration(context: DurableOrchestrationContext) ) rewritten_raw = yield writer.run( messages=rewrite_prompt, - thread=writer_thread, + session=writer_session, options={"response_format": GeneratedContent}, ) diff --git a/python/samples/04-hosting/durabletask/01_single_agent/client.py b/python/samples/04-hosting/durabletask/01_single_agent/client.py index d88c9e857f..7940d0421c 100644 --- 
a/python/samples/04-hosting/durabletask/01_single_agent/client.py +++ b/python/samples/04-hosting/durabletask/01_single_agent/client.py @@ -69,9 +69,9 @@ def run_client(agent_client: DurableAIAgentClient) -> None: logger.debug("Getting reference to Joker agent...") joker = agent_client.get_agent("Joker") - # Create a new thread for the conversation - thread = joker.get_new_thread() - logger.debug(f"Thread ID: {thread.session_id}") + # Create a new session for the conversation + session = joker.create_session() + logger.debug(f"Session ID: {session.session_id}") logger.info("Start chatting with the Joker agent! (Type 'exit' to quit)") # Interactive conversation loop @@ -94,7 +94,7 @@ def run_client(agent_client: DurableAIAgentClient) -> None: # Send message to agent and get response try: - response = joker.run(user_message, thread=thread) + response = joker.run(user_message, session=session) logger.info(f"Joker: {response.text} \n") except Exception as e: logger.error(f"Error getting response: {e}") diff --git a/python/samples/04-hosting/durabletask/02_multi_agent/client.py b/python/samples/04-hosting/durabletask/02_multi_agent/client.py index 4586186408..ee9f0e7ab6 100644 --- a/python/samples/04-hosting/durabletask/02_multi_agent/client.py +++ b/python/samples/04-hosting/durabletask/02_multi_agent/client.py @@ -70,30 +70,30 @@ def run_client(agent_client: DurableAIAgentClient) -> None: # Get reference to WeatherAgent weather_agent = agent_client.get_agent("WeatherAgent") - weather_thread = weather_agent.get_new_thread() + weather_session = weather_agent.create_session() - logger.debug(f"Created weather conversation thread: {weather_thread.session_id}") + logger.debug(f"Created weather conversation session: {weather_session.session_id}") # Test WeatherAgent weather_message = "What is the weather in Seattle?" 
logger.info(f"User: {weather_message}") - weather_response = weather_agent.run(weather_message, thread=weather_thread) + weather_response = weather_agent.run(weather_message, session=weather_session) logger.info(f"WeatherAgent: {weather_response.text} \n") logger.debug("Testing MathAgent") # Get reference to MathAgent math_agent = agent_client.get_agent("MathAgent") - math_thread = math_agent.get_new_thread() + math_session = math_agent.create_session() - logger.debug(f"Created math conversation thread: {math_thread.session_id}") + logger.debug(f"Created math conversation session: {math_session.session_id}") # Test MathAgent math_message = "Calculate a 20% tip on a $50 bill" logger.info(f"User: {math_message}") - math_response = math_agent.run(math_message, thread=math_thread) + math_response = math_agent.run(math_message, session=math_session) logger.info(f"MathAgent: {math_response.text} \n") logger.debug("Both agents completed successfully!") diff --git a/python/samples/04-hosting/durabletask/03_single_agent_streaming/client.py b/python/samples/04-hosting/durabletask/03_single_agent_streaming/client.py index c65b27b2a9..ab0d82ff41 100644 --- a/python/samples/04-hosting/durabletask/03_single_agent_streaming/client.py +++ b/python/samples/04-hosting/durabletask/03_single_agent_streaming/client.py @@ -140,14 +140,14 @@ def run_client(agent_client: DurableAIAgentClient) -> None: logger.debug("Getting reference to TravelPlanner agent...") travel_planner = agent_client.get_agent("TravelPlanner") - # Create a new thread for the conversation - thread = travel_planner.get_new_thread() - if not thread.session_id: - logger.error("Failed to create a new thread with session ID!") + # Create a new session for the conversation + session = travel_planner.create_session() + if not session.session_id: + logger.error("Failed to create a new session with session ID!") return - key = thread.session_id.key - logger.info(f"Thread ID: {key}") + key = session.session_id.key + 
logger.info(f"Session ID: {key}") # Get user input print("\nEnter your travel planning request:") @@ -164,7 +164,7 @@ def run_client(agent_client: DurableAIAgentClient) -> None: # Start the agent run with wait_for_response=False for non-blocking execution # This signals the agent to start processing without waiting for completion # The agent will execute in the background and write chunks to Redis - travel_planner.run(user_message, thread=thread, options={"wait_for_response": False}) + travel_planner.run(user_message, session=session, options={"wait_for_response": False}) # Stream the response from Redis # This demonstrates that the client can stream from Redis while diff --git a/python/samples/04-hosting/durabletask/04_single_agent_orchestration_chaining/worker.py b/python/samples/04-hosting/durabletask/04_single_agent_orchestration_chaining/worker.py index 581c95a06a..ecc44a8959 100644 --- a/python/samples/04-hosting/durabletask/04_single_agent_orchestration_chaining/worker.py +++ b/python/samples/04-hosting/durabletask/04_single_agent_orchestration_chaining/worker.py @@ -87,17 +87,17 @@ def single_agent_chaining_orchestration( # Get the writer agent using the agent context writer = agent_context.get_agent(WRITER_AGENT_NAME) - # Create a new thread for the conversation - this will be shared across both runs - writer_thread = writer.get_new_thread() + # Create a new session for the conversation - this will be shared across both runs + writer_session = writer.create_session() - logger.debug(f"[Orchestration] Created thread: {writer_thread.session_id}") + logger.debug(f"[Orchestration] Created session: {writer_session.session_id}") prompt = "Write a concise inspirational sentence about learning." 
# First run: Generate an initial inspirational sentence logger.info("[Orchestration] First agent run: Generating initial sentence about: %s", prompt) initial_response = yield writer.run( messages=prompt, - thread=writer_thread, + session=writer_session, ) logger.info(f"[Orchestration] Initial response: {initial_response.text}") @@ -110,7 +110,7 @@ def single_agent_chaining_orchestration( logger.info("[Orchestration] Second agent run: Refining the sentence: %s", improved_prompt) refined_response = yield writer.run( messages=improved_prompt, - thread=writer_thread, + session=writer_session, ) logger.info(f"[Orchestration] Refined response: {refined_response.text}") diff --git a/python/samples/04-hosting/durabletask/05_multi_agent_orchestration_concurrency/worker.py b/python/samples/04-hosting/durabletask/05_multi_agent_orchestration_concurrency/worker.py index 67861cc8c9..716355ec8b 100644 --- a/python/samples/04-hosting/durabletask/05_multi_agent_orchestration_concurrency/worker.py +++ b/python/samples/04-hosting/durabletask/05_multi_agent_orchestration_concurrency/worker.py @@ -80,15 +80,15 @@ def multi_agent_concurrent_orchestration(context: OrchestrationContext, prompt: physicist = agent_context.get_agent(PHYSICIST_AGENT_NAME) chemist = agent_context.get_agent(CHEMIST_AGENT_NAME) - # Create separate threads for each agent - physicist_thread = physicist.get_new_thread() - chemist_thread = chemist.get_new_thread() + # Create separate sessions for each agent + physicist_session = physicist.create_session() + chemist_session = chemist.create_session() - logger.debug(f"[Orchestration] Created threads - Physicist: {physicist_thread.session_id}, Chemist: {chemist_thread.session_id}") + logger.debug(f"[Orchestration] Created sessions - Physicist: {physicist_session.session_id}, Chemist: {chemist_session.session_id}") # Create tasks from agent.run() calls - these return DurableAgentTask instances - physicist_task = physicist.run(messages=str(prompt), 
thread=physicist_thread) - chemist_task = chemist.run(messages=str(prompt), thread=chemist_thread) + physicist_task = physicist.run(messages=str(prompt), session=physicist_session) + chemist_task = chemist.run(messages=str(prompt), session=chemist_session) logger.debug("[Orchestration] Created agent tasks, executing concurrently...") diff --git a/python/samples/04-hosting/durabletask/07_single_agent_orchestration_hitl/worker.py b/python/samples/04-hosting/durabletask/07_single_agent_orchestration_hitl/worker.py index da86d869a0..d90973ef1d 100644 --- a/python/samples/04-hosting/durabletask/07_single_agent_orchestration_hitl/worker.py +++ b/python/samples/04-hosting/durabletask/07_single_agent_orchestration_hitl/worker.py @@ -150,16 +150,16 @@ def content_generation_hitl_orchestration( # Get the writer agent writer = agent_context.get_agent(WRITER_AGENT_NAME) - writer_thread = writer.get_new_thread() + writer_session = writer.create_session() - logger.info(f"ThreadID: {writer_thread.session_id}") + logger.info(f"SessionID: {writer_session.session_id}") # Generate initial content logger.info("[Orchestration] Generating initial content...") initial_response: AgentResponse = yield writer.run( messages=f"Write a short article about '{payload.topic}'.", - thread=writer_thread, + session=writer_session, options={"response_format": GeneratedContent}, ) content = cast(GeneratedContent, initial_response.value) @@ -251,11 +251,11 @@ def content_generation_hitl_orchestration( logger.debug("[Orchestration] Regenerating content with feedback...") - logger.warning(f"Regenerating with ThreadID: {writer_thread.session_id}") + logger.warning(f"Regenerating with SessionID: {writer_session.session_id}") rewrite_response: AgentResponse = yield writer.run( messages=rewrite_prompt, - thread=writer_thread, + session=writer_session, options={"response_format": GeneratedContent}, ) rewritten_content = cast(GeneratedContent, rewrite_response.value) diff --git 
a/python/samples/05-end-to-end/hosted_agents/agent_with_text_search_rag/main.py b/python/samples/05-end-to-end/hosted_agents/agent_with_text_search_rag/main.py index 8e6c77d712..5c28917b51 100644 --- a/python/samples/05-end-to-end/hosted_agents/agent_with_text_search_rag/main.py +++ b/python/samples/05-end-to-end/hosted_agents/agent_with_text_search_rag/main.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from typing import Any -from agent_framework import Context, ContextProvider, Message +from agent_framework import AgentSession, BaseContextProvider, Message, SessionContext from agent_framework.azure import AzureOpenAIChatClient from azure.ai.agentserver.agentframework import from_agent_framework # pyright: ignore[reportUnknownVariableType] from azure.identity import DefaultAzureCredential @@ -24,19 +24,30 @@ class TextSearchResult: text: str -class TextSearchContextProvider(ContextProvider): +class TextSearchContextProvider(BaseContextProvider): """A simple context provider that simulates text search results based on keywords in the user's message.""" - def _get_most_recent_message(self, messages: Message | MutableSequence[Message]) -> Message: + def __init__(self): + super().__init__("text-search") + + def _get_most_recent_message(self, messages: list[Message]) -> Message: """Helper method to extract the most recent message from the input.""" - if isinstance(messages, Message): - return messages if messages: return messages[-1] raise ValueError("No messages provided") @override - async def invoking(self, messages: Message | MutableSequence[Message], **kwargs: Any) -> Context: + async def before_run( + self, + *, + agent: Any, + session: AgentSession | None, + context: SessionContext, + state: dict[str, Any], + ) -> None: + messages = context.get_messages() + if not messages: + return message = self._get_most_recent_message(messages) query = message.text.lower() @@ -80,14 +91,15 @@ async def invoking(self, messages: Message | MutableSequence[Message], 
**kwargs: ) if not results: - return Context() + return - return Context( - messages=[ + context.extend_messages( + self.source_id, + [ Message( role="user", text="\n\n".join(json.dumps(result.__dict__, indent=2) for result in results) ) - ] + ], ) @@ -99,7 +111,7 @@ def main(): "You are a helpful support specialist for Contoso Outdoors. " "Answer questions using the provided context and cite the source document when available." ), - context_provider=TextSearchContextProvider(), + context_providers=[TextSearchContextProvider()], ) # Run the agent as a hosted agent diff --git a/python/samples/autogen-migration/single_agent/03_assistant_agent_thread_and_stream.py b/python/samples/autogen-migration/single_agent/03_assistant_agent_thread_and_stream.py index 73fb0f3c62..5e1de9d873 100644 --- a/python/samples/autogen-migration/single_agent/03_assistant_agent_thread_and_stream.py +++ b/python/samples/autogen-migration/single_agent/03_assistant_agent_thread_and_stream.py @@ -46,7 +46,7 @@ async def run_autogen() -> None: async def run_agent_framework() -> None: - """Agent Framework agent with explicit thread and streaming.""" + """Agent Framework agent with explicit session and streaming.""" from agent_framework.openai import OpenAIChatClient client = OpenAIChatClient(model_id="gpt-4.1-mini") @@ -55,22 +55,22 @@ async def run_agent_framework() -> None: instructions="You are a helpful math tutor.", ) - print("[Agent Framework] Conversation with thread:") - # Create a thread to maintain state - thread = agent.get_new_thread() + print("[Agent Framework] Conversation with session:") + # Create a session to maintain state + session = agent.create_session() - # First turn - pass thread to maintain history - result1 = await agent.run("What is 15 + 27?", thread=thread) + # First turn - pass session to maintain history + result1 = await agent.run("What is 15 + 27?", session=session) print(f" Q1: {result1.text}") - # Second turn - agent remembers context via thread - result2 = await 
agent.run("What about that number times 2?", thread=thread) + # Second turn - agent remembers context via session + result2 = await agent.run("What about that number times 2?", session=session) print(f" Q2: {result2.text}") print("\n[Agent Framework] Streaming response:") # Stream response print(" ", end="") - async for chunk in agent.run("Count from 1 to 5", thread=thread, stream=True): + async for chunk in agent.run("Count from 1 to 5", session=session, stream=True): if chunk.text: print(chunk.text, end="", flush=True) print() diff --git a/python/samples/semantic-kernel-migration/azure_ai_agent/03_azure_ai_agent_threads_and_followups.py b/python/samples/semantic-kernel-migration/azure_ai_agent/03_azure_ai_agent_threads_and_followups.py index ae0b28e37d..ecd4a2b0b4 100644 --- a/python/samples/semantic-kernel-migration/azure_ai_agent/03_azure_ai_agent_threads_and_followups.py +++ b/python/samples/semantic-kernel-migration/azure_ai_agent/03_azure_ai_agent_threads_and_followups.py @@ -52,19 +52,19 @@ async def run_agent_framework() -> None: instructions="Track follow-up questions within the same thread.", ) as agent, ): - thread = agent.get_new_thread() - # AF threads are explicit and can be serialized for external storage. - first = await agent.run("Outline the onboarding checklist.", thread=thread) + session = agent.create_session() + # AF sessions are explicit and can be serialized for external storage. 
+ first = await agent.run("Outline the onboarding checklist.", session=session) print("[AF][turn1]", first.text) second = await agent.run( "Highlight the items that require legal review.", - thread=thread, + session=session, ) print("[AF][turn2]", second.text) - serialized = await thread.serialize() - print("[AF][thread-json]", serialized) + serialized = session.to_dict() + print("[AF][session-json]", serialized) async def main() -> None: diff --git a/python/samples/semantic-kernel-migration/chat_completion/02_chat_completion_with_tool.py b/python/samples/semantic-kernel-migration/chat_completion/02_chat_completion_with_tool.py index 2bf7266018..1267027364 100644 --- a/python/samples/semantic-kernel-migration/chat_completion/02_chat_completion_with_tool.py +++ b/python/samples/semantic-kernel-migration/chat_completion/02_chat_completion_with_tool.py @@ -56,10 +56,10 @@ async def specials() -> str: instructions="Answer menu questions accurately.", tools=[specials], ) - thread = chat_agent.get_new_thread() + session = chat_agent.create_session() reply = await chat_agent.run( "What soup can I order today?", - thread=thread, + session=session, tool_choice="auto", ) print("[AF]", reply.text) diff --git a/python/samples/semantic-kernel-migration/chat_completion/03_chat_completion_thread_and_stream.py b/python/samples/semantic-kernel-migration/chat_completion/03_chat_completion_thread_and_stream.py index d357c2f957..78021a81ac 100644 --- a/python/samples/semantic-kernel-migration/chat_completion/03_chat_completion_thread_and_stream.py +++ b/python/samples/semantic-kernel-migration/chat_completion/03_chat_completion_thread_and_stream.py @@ -48,23 +48,23 @@ async def run_semantic_kernel() -> None: async def run_agent_framework() -> None: from agent_framework.openai import OpenAIChatClient - # AF thread objects are requested explicitly from the agent. + # AF session objects are requested explicitly from the agent. 
chat_agent = OpenAIChatClient().as_agent( name="Writer", instructions="Keep answers short and friendly.", ) - thread = chat_agent.get_new_thread() + session = chat_agent.create_session() first = await chat_agent.run( "Suggest a catchy headline for our product launch.", - thread=thread, + session=session, ) print("[AF]", first.text) print("[AF][stream]", end=" ") async for chunk in chat_agent.run( "Draft a 2 sentence blurb.", - thread=thread, + session=session, stream=True, ): if chunk.text: diff --git a/python/samples/semantic-kernel-migration/openai_assistant/01_basic_openai_assistant.py b/python/samples/semantic-kernel-migration/openai_assistant/01_basic_openai_assistant.py index 34709fbaf1..fbdd163fa7 100644 --- a/python/samples/semantic-kernel-migration/openai_assistant/01_basic_openai_assistant.py +++ b/python/samples/semantic-kernel-migration/openai_assistant/01_basic_openai_assistant.py @@ -50,7 +50,7 @@ async def run_agent_framework() -> None: print("[AF]", reply.text) follow_up = await assistant_agent.run( "How many residents live there?", - thread=assistant_agent.get_new_thread(), + session=assistant_agent.create_session(), ) print("[AF][follow-up]", follow_up.text) From 7d71e8f41c2f0d9f5c8acead7b221025bc2e567e Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Feb 2026 12:37:44 +0100 Subject: [PATCH 04/28] fix: update remaining ag-ui references (client docstring, getting_started sample) --- .../ag-ui/agent_framework_ag_ui/_client.py | 6 +- .../getting_started/client_with_agent.py | 79 ++++--------------- 2 files changed, 20 insertions(+), 65 deletions(-) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_client.py b/python/packages/ag-ui/agent_framework_ag_ui/_client.py index f1dba1b078..373e6321bb 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_client.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_client.py @@ -171,11 +171,11 @@ class AGUIChatClient( client = AGUIChatClient(endpoint="http://localhost:8888/") agent = 
Agent(name="assistant", client=client) - thread = await agent.get_new_thread() + session = agent.create_session() # Agent automatically maintains history and sends full context - response = await agent.run("Hello!", thread=thread) - response2 = await agent.run("How are you?", thread=thread) + response = await agent.run("Hello!", session=session) + response2 = await agent.run("How are you?", session=session) Streaming usage: diff --git a/python/packages/ag-ui/getting_started/client_with_agent.py b/python/packages/ag-ui/getting_started/client_with_agent.py index f0d7630294..11ba4c95ca 100644 --- a/python/packages/ag-ui/getting_started/client_with_agent.py +++ b/python/packages/ag-ui/getting_started/client_with_agent.py @@ -4,10 +4,10 @@ This demonstrates the HYBRID pattern matching .NET AGUIClient implementation: -1. AgentThread Pattern (like .NET): - - Create thread with agent.get_new_thread() - - Pass thread to agent.run(stream=True) on each turn - - Thread automatically maintains conversation history via message_store +1. AgentSession Pattern (like .NET): + - Create session with agent.create_session() + - Pass session to agent.run(stream=True) on each turn + - Session maintains conversation context via context providers 2. Hybrid Tool Execution: - AGUIChatClient uses function invocation mixin @@ -15,7 +15,7 @@ - Server may also have its own tools that execute server-side - Both work together: server LLM decides which tool to call, decorator handles client execution -This matches .NET pattern: thread maintains state, tools execute on appropriate side. +This matches .NET pattern: session maintains state, tools execute on appropriate side. 
""" from __future__ import annotations @@ -59,13 +59,13 @@ async def main(): This matches the .NET pattern from Program.cs where: - AIAgent agent = chatClient.CreateAIAgent(tools: [...]) - - AgentThread thread = agent.GetNewThread() - - RunStreamingAsync(messages, thread) + - AgentSession session = agent.CreateSession() + - RunStreamingAsync(messages, session) Python equivalent: - agent = Agent(client=AGUIChatClient(...), tools=[...]) - - thread = agent.get_new_thread() # Creates thread with message_store - - agent.run(message, stream=True, thread=thread) # Thread accumulates history + - session = agent.create_session() # Creates session + - agent.run(message, stream=True, session=session) # Session tracks context """ server_url = os.environ.get("AGUI_SERVER_URL", "http://127.0.0.1:5100/") @@ -74,7 +74,7 @@ async def main(): print("=" * 70) print(f"\nServer: {server_url}") print("\nThis example demonstrates:") - print(" 1. AgentThread maintains conversation state (like .NET)") + print(" 1. AgentSession maintains conversation state (like .NET)") print(" 2. Client-side tools execute locally via function invocation mixin") print(" 3. Server may have additional tools that execute server-side") print(" 4. 
HYBRID: Client and server tools work together simultaneously\n") @@ -90,8 +90,8 @@ async def main(): tools=[get_weather], ) - # Create a thread to maintain conversation state (like .NET AgentThread) - thread = agent.get_new_thread() + # Create a session to maintain conversation state (like .NET AgentSession) + session = agent.create_session() print("=" * 70) print("CONVERSATION WITH HISTORY") @@ -99,21 +99,21 @@ async def main(): # Turn 1: Introduce print("\nUser: My name is Alice and I live in Seattle\n") - async for chunk in agent.run("My name is Alice and I live in Seattle", stream=True, thread=thread): + async for chunk in agent.run("My name is Alice and I live in Seattle", stream=True, session=session): if chunk.text: print(chunk.text, end="", flush=True) print("\n") # Turn 2: Ask about name (tests history) print("User: What's my name?\n") - async for chunk in agent.run("What's my name?", stream=True, thread=thread): + async for chunk in agent.run("What's my name?", stream=True, session=session): if chunk.text: print(chunk.text, end="", flush=True) print("\n") # Turn 3: Ask about location (tests history) print("User: Where do I live?\n") - async for chunk in agent.run("Where do I live?", stream=True, thread=thread): + async for chunk in agent.run("Where do I live?", stream=True, session=session): if chunk.text: print(chunk.text, end="", flush=True) print("\n") @@ -123,7 +123,7 @@ async def main(): async for chunk in agent.run( "What's the weather forecast for today in Seattle?", stream=True, - thread=thread, + session=session, ): if chunk.text: print(chunk.text, end="", flush=True) @@ -131,56 +131,11 @@ async def main(): # Turn 5: Test server-side tool (get_time_zone is server-side only) print("User: What time zone is Seattle in?\n") - async for chunk in agent.run("What time zone is Seattle in?", stream=True, thread=thread): + async for chunk in agent.run("What time zone is Seattle in?", stream=True, session=session): if chunk.text: print(chunk.text, end="", 
flush=True) print("\n") - # Show thread state - if thread.message_store: - - def _preview_for_message(m) -> str: - # Prefer plain text when present - if getattr(m, "text", ""): - t = m.text - return (t[:60] + "...") if len(t) > 60 else t - # Build from contents when no direct text - parts: list[str] = [] - for c in getattr(m, "contents", []) or []: - content_type = getattr(c, "type", None) - if content_type == "function_call": - args = getattr(c, "arguments", None) - if isinstance(args, dict): - try: - import json as _json - - args_str = _json.dumps(args) - except Exception: - args_str = str(args) - else: - args_str = str(args or "{}") - parts.append(f"tool_call {getattr(c, 'name', '?')} {args_str}") - elif content_type == "function_result": - call_id = getattr(c, "call_id", "?") - result = getattr(c, "result", None) - parts.append(f"tool_result[{call_id}]: {str(result)[:40]}") - elif content_type == "text": - text = getattr(c, "text", None) - if text: - parts.append(text) - else: - typename = getattr(c, "type", c.__class__.__name__) - parts.append(f"<{typename}>") - preview = " | ".join(parts) if parts else "" - return (preview[:60] + "...") if len(preview) > 60 else preview - - messages = await thread.message_store.list_messages() - print(f"\n[THREAD STATE] {len(messages)} messages in thread's message_store") - for i, msg in enumerate(messages[-6:], 1): # Show last 6 - role = msg.role if hasattr(msg.role, "value") else str(msg.role) - text_preview = _preview_for_message(msg) - print(f" {i}. 
[{role}]: {text_preview}") - except ConnectionError as e: print(f"\n\033[91mConnection Error: {e}\033[0m") print("\nMake sure an AG-UI server is running at the specified endpoint.") From 7edb857a344e98ce564fef66ee7351ae3dac6b72 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Feb 2026 12:40:58 +0100 Subject: [PATCH 05/28] fix: make get_session service_session_id keyword-only to avoid confusion with session_id --- python/packages/core/agent_framework/_agents.py | 6 +++--- python/packages/core/tests/core/test_agents.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index d99f14eefb..7fa9842c82 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -217,7 +217,7 @@ def create_session(self, **kwargs): return AgentSession(**kwargs) - def get_session(self, service_session_id, **kwargs): + def get_session(self, *, service_session_id, **kwargs): from agent_framework import AgentSession return AgentSession(service_session_id=service_session_id, **kwargs) @@ -289,7 +289,7 @@ def create_session(self, **kwargs: Any) -> AgentSession: """Creates a new conversation session.""" ... - def get_session(self, service_session_id: str, **kwargs: Any) -> AgentSession: + def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession: """Gets or creates a session for a service-managed session ID.""" ... @@ -398,7 +398,7 @@ def create_session(self, *, session_id: str | None = None, **kwargs: Any) -> Age """ return AgentSession(session_id=session_id) - def get_session(self, service_session_id: str, *, session_id: str | None = None, **kwargs: Any) -> AgentSession: + def get_session(self, *, service_session_id: str, session_id: str | None = None, **kwargs: Any) -> AgentSession: """Get or create a session for a service-managed session ID. 
Args: diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index c2976c85bc..cd7b2c7aba 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -146,7 +146,7 @@ async def test_chat_client_agent_run_with_session(chat_client_base: SupportsChat client=chat_client_base, tools={"type": "code_interpreter"}, ) - session = agent.get_session("123") + session = agent.get_session(service_session_id="123") result = await agent.run("Hello", session=session) assert result.text == "test response" @@ -259,7 +259,7 @@ async def test_chat_agent_context_providers_after_run(chat_client_base: Supports agent = Agent(client=chat_client_base, context_providers=[mock_provider]) - session = agent.get_session("test-thread-id") + session = agent.get_session(service_session_id="test-thread-id") await agent.run("Hello", session=session) assert mock_provider.after_run_called @@ -345,7 +345,7 @@ async def test_chat_agent_context_providers_with_service_session_id(chat_client_ agent = Agent(client=chat_client_base, context_providers=[mock_provider]) # Use existing service-managed session - session = agent.get_session("existing-thread-id") + session = agent.get_session(service_session_id="existing-thread-id") await agent.run("Hello", session=session) # after_run should be called @@ -829,7 +829,7 @@ async def test_agent_get_session_with_service_session_id( """Test that get_session creates a session with service_session_id.""" agent = Agent(client=chat_client_base, tools=[tool_tool]) - session = agent.get_session("test-thread-123") + session = agent.get_session(service_session_id="test-thread-123") assert session is not None assert session.service_session_id == "test-thread-123" From c6c0d721288144c4768499798e7ceb2c2cf88187 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Feb 2026 12:42:39 +0100 Subject: [PATCH 06/28] refactor: rename _RunContext.thread_messages to 
session_messages --- python/packages/core/agent_framework/_agents.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 7fa9842c82..2ccb657b6d 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -157,7 +157,7 @@ class _RunContext(TypedDict): session: AgentSession | None session_context: SessionContext input_messages: list[Message] - thread_messages: list[Message] + session_messages: list[Message] agent_name: str chat_options: dict[str, Any] filtered_kwargs: dict[str, Any] @@ -850,7 +850,7 @@ async def _run_non_streaming() -> AgentResponse[Any]: kwargs=kwargs, ) response = await self.client.get_response( # type: ignore[call-overload] - messages=ctx["thread_messages"], + messages=ctx["session_messages"], stream=False, options=ctx["chat_options"], **ctx["filtered_kwargs"], @@ -918,7 +918,7 @@ async def _get_stream() -> ResponseStream[ChatResponseUpdate, ChatResponse]: ) ctx: _RunContext = ctx_holder["ctx"] # type: ignore[assignment] # Safe: we just assigned it return self.client.get_response( # type: ignore[call-overload, no-any-return] - messages=ctx["thread_messages"], + messages=ctx["session_messages"], stream=True, options=ctx["chat_options"], **ctx["filtered_kwargs"], @@ -1034,8 +1034,8 @@ async def _prepare_run_context( run_opts = {k: v for k, v in run_opts.items() if v is not None} co = _merge_options(chat_options, run_opts) - # Build thread_messages from session context: context messages + input messages - thread_messages: list[Message] = session_context.get_messages(include_input=True) + # Build session_messages from session context: context messages + input messages + session_messages: list[Message] = session_context.get_messages(include_input=True) # Ensure session is forwarded in kwargs for tool invocation finalize_kwargs = dict(kwargs) @@ -1047,7 +1047,7 @@ 
async def _prepare_run_context( "session": session, "session_context": session_context, "input_messages": input_messages, - "thread_messages": thread_messages, + "session_messages": session_messages, "agent_name": agent_name, "chat_options": co, "filtered_kwargs": filtered_kwargs, From c65254c6b3c1b2e1e9c6a4c613dceb1b7f32eb55 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Feb 2026 12:55:03 +0100 Subject: [PATCH 07/28] refactor: remove _threads.py, _memory.py, and old provider files; migrate devui to use plain message lists --- .../__init__.py | 3 +- .../_context_provider.py | 49 +- .../_search_provider.py | 991 ------------------ .../packages/core/agent_framework/__init__.py | 2 - .../packages/core/agent_framework/_memory.py | 181 ---- .../core/agent_framework/_serialization.py | 39 +- .../packages/core/agent_framework/_threads.py | 507 --------- .../packages/core/tests/core/test_memory.py | 136 --- .../packages/core/tests/core/test_threads.py | 600 ----------- .../agent_framework_devui/_conversations.py | 261 +++-- .../devui/tests/devui/test_conversations.py | 26 +- .../mem0/agent_framework_mem0/_provider.py | 239 ----- .../orchestrations/tests/test_handoff.py | 1 - .../_chat_message_store.py | 595 ----------- .../redis/agent_framework_redis/_provider.py | 595 ----------- 15 files changed, 193 insertions(+), 4032 deletions(-) delete mode 100644 python/packages/azure-ai-search/agent_framework_azure_ai_search/_search_provider.py delete mode 100644 python/packages/core/agent_framework/_memory.py delete mode 100644 python/packages/core/agent_framework/_threads.py delete mode 100644 python/packages/core/tests/core/test_memory.py delete mode 100644 python/packages/core/tests/core/test_threads.py delete mode 100644 python/packages/mem0/agent_framework_mem0/_provider.py delete mode 100644 python/packages/redis/agent_framework_redis/_chat_message_store.py delete mode 100644 python/packages/redis/agent_framework_redis/_provider.py diff --git 
a/python/packages/azure-ai-search/agent_framework_azure_ai_search/__init__.py b/python/packages/azure-ai-search/agent_framework_azure_ai_search/__init__.py index 7308f427c5..e8782e2117 100644 --- a/python/packages/azure-ai-search/agent_framework_azure_ai_search/__init__.py +++ b/python/packages/azure-ai-search/agent_framework_azure_ai_search/__init__.py @@ -2,8 +2,7 @@ import importlib.metadata -from ._context_provider import AzureAISearchContextProvider -from ._search_provider import AzureAISearchSettings +from ._context_provider import AzureAISearchContextProvider, AzureAISearchSettings try: __version__ = importlib.metadata.version(__name__) diff --git a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py index 127359372b..0b0a24768d 100644 --- a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py +++ b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py @@ -14,6 +14,7 @@ from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Message from agent_framework._logging import get_logger +from agent_framework._pydantic import AFBaseSettings from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext from agent_framework._settings import load_settings from agent_framework.exceptions import ServiceInitializationError @@ -41,8 +42,54 @@ VectorizableTextQuery, VectorizedQuery, ) +from pydantic import SecretStr + + +class AzureAISearchSettings(AFBaseSettings): + """Settings for Azure AI Search Context Provider with auto-loading from environment. + + The settings are first loaded from environment variables with the prefix 'AZURE_SEARCH_'. + If the environment variables are not found, the settings can be loaded from a .env file. + + Keyword Args: + endpoint: Azure AI Search endpoint URL. + Can be set via environment variable AZURE_SEARCH_ENDPOINT. 
+ index_name: Name of the search index. + Can be set via environment variable AZURE_SEARCH_INDEX_NAME. + knowledge_base_name: Name of an existing Knowledge Base (for agentic mode). + Can be set via environment variable AZURE_SEARCH_KNOWLEDGE_BASE_NAME. + api_key: API key for authentication (optional, use managed identity if not provided). + Can be set via environment variable AZURE_SEARCH_API_KEY. + env_file_path: If provided, the .env settings are read from this file path location. + env_file_encoding: The encoding of the .env file, defaults to 'utf-8'. + + Examples: + .. code-block:: python + + from agent_framework_azure_ai_search import AzureAISearchSettings + + # Using environment variables + # Set AZURE_SEARCH_ENDPOINT=https://mysearch.search.windows.net + # Set AZURE_SEARCH_INDEX_NAME=my-index + settings = AzureAISearchSettings() + + # Or passing parameters directly + settings = AzureAISearchSettings( + endpoint="https://mysearch.search.windows.net", + index_name="my-index", + ) + + # Or loading from a .env file + settings = AzureAISearchSettings(env_file_path="path/to/.env") + """ + + env_prefix: ClassVar[str] = "AZURE_SEARCH_" + + endpoint: str | None = None + index_name: str | None = None + knowledge_base_name: str | None = None + api_key: SecretStr | None = None -from ._search_provider import AzureAISearchSettings if TYPE_CHECKING: from agent_framework._agents import SupportsAgentRun diff --git a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_search_provider.py b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_search_provider.py deleted file mode 100644 index 5e47b37b00..0000000000 --- a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_search_provider.py +++ /dev/null @@ -1,991 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved.
- - -from __future__ import annotations - -import sys -from collections.abc import Awaitable, Callable, MutableSequence -from typing import TYPE_CHECKING, Any, Literal - -from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Context, ContextProvider, Message -from agent_framework._logging import get_logger -from agent_framework._settings import SecretString, load_settings -from agent_framework.exceptions import ServiceInitializationError -from azure.core.credentials import AzureKeyCredential -from azure.core.credentials_async import AsyncTokenCredential -from azure.core.exceptions import ResourceNotFoundError -from azure.search.documents.aio import SearchClient -from azure.search.documents.indexes.aio import SearchIndexClient -from azure.search.documents.indexes.models import ( - AzureOpenAIVectorizerParameters, - KnowledgeBase, - KnowledgeBaseAzureOpenAIModel, - KnowledgeRetrievalLowReasoningEffort, - KnowledgeRetrievalMediumReasoningEffort, - KnowledgeRetrievalMinimalReasoningEffort, - KnowledgeRetrievalOutputMode, - KnowledgeRetrievalReasoningEffort, - KnowledgeSourceReference, - SearchIndexKnowledgeSource, - SearchIndexKnowledgeSourceParameters, -) -from azure.search.documents.models import ( - QueryCaptionType, - QueryType, - VectorizableTextQuery, - VectorizedQuery, -) - -# Type checking imports for optional agentic mode dependencies -if TYPE_CHECKING: - from azure.search.documents.knowledgebases.aio import KnowledgeBaseRetrievalClient - from azure.search.documents.knowledgebases.models import ( - KnowledgeBaseMessage, - KnowledgeBaseMessageTextContent, - KnowledgeBaseRetrievalRequest, - KnowledgeRetrievalIntent, - KnowledgeRetrievalSemanticIntent, - ) - from azure.search.documents.knowledgebases.models import ( - KnowledgeRetrievalLowReasoningEffort as KBRetrievalLowReasoningEffort, - ) - from azure.search.documents.knowledgebases.models import ( - KnowledgeRetrievalMediumReasoningEffort as KBRetrievalMediumReasoningEffort, - ) - from 
azure.search.documents.knowledgebases.models import ( - KnowledgeRetrievalMinimalReasoningEffort as KBRetrievalMinimalReasoningEffort, - ) - from azure.search.documents.knowledgebases.models import ( - KnowledgeRetrievalOutputMode as KBRetrievalOutputMode, - ) - from azure.search.documents.knowledgebases.models import ( - KnowledgeRetrievalReasoningEffort as KBRetrievalReasoningEffort, - ) - -# Runtime imports for agentic mode (optional dependency) -try: - from azure.search.documents.knowledgebases.aio import KnowledgeBaseRetrievalClient - from azure.search.documents.knowledgebases.models import ( - KnowledgeBaseMessage, - KnowledgeBaseMessageTextContent, - KnowledgeBaseRetrievalRequest, - KnowledgeRetrievalIntent, - KnowledgeRetrievalSemanticIntent, - ) - from azure.search.documents.knowledgebases.models import ( - KnowledgeRetrievalLowReasoningEffort as KBRetrievalLowReasoningEffort, - ) - from azure.search.documents.knowledgebases.models import ( - KnowledgeRetrievalMediumReasoningEffort as KBRetrievalMediumReasoningEffort, - ) - from azure.search.documents.knowledgebases.models import ( - KnowledgeRetrievalMinimalReasoningEffort as KBRetrievalMinimalReasoningEffort, - ) - from azure.search.documents.knowledgebases.models import ( - KnowledgeRetrievalOutputMode as KBRetrievalOutputMode, - ) - from azure.search.documents.knowledgebases.models import ( - KnowledgeRetrievalReasoningEffort as KBRetrievalReasoningEffort, - ) - - _agentic_retrieval_available = True -except ImportError: - _agentic_retrieval_available = False - -if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover -else: - from typing_extensions import override # type: ignore[import] # pragma: no cover - -if sys.version_info >= (3, 11): - from typing import Self, TypedDict # pragma: no cover -else: - from typing_extensions import Self, TypedDict # pragma: no cover - -"""Azure AI Search Context Provider for Agent Framework. 
- -This module provides context providers for Azure AI Search integration with two modes: -- Agentic: Recommended for most scenarios. Uses Knowledge Bases for query planning and - multi-hop reasoning. Slightly slower with more token consumption, but more accurate. -- Semantic: Fast hybrid search (vector + keyword) with semantic ranker. Best for simple - queries where speed is critical. - -See: https://techcommunity.microsoft.com/blog/azure-ai-foundry-blog/foundry-iq-boost-response-relevance-by-36-with-agentic-retrieval/4470720 -""" - - -# Module-level constants -logger = get_logger("agent_framework.azure") -_DEFAULT_AGENTIC_MESSAGE_HISTORY_COUNT = 10 - - -class AzureAISearchSettings(TypedDict, total=False): - """Settings for Azure AI Search Context Provider with auto-loading from environment. - - The settings are first loaded from environment variables with the prefix 'AZURE_SEARCH_'. - If the environment variables are not found, the settings can be loaded from a .env file. - - Keyword Args: - endpoint: Azure AI Search endpoint URL. - Can be set via environment variable AZURE_SEARCH_ENDPOINT. - index_name: Name of the search index. - Can be set via environment variable AZURE_SEARCH_INDEX_NAME. - knowledge_base_name: Name of an existing Knowledge Base (for agentic mode). - Can be set via environment variable AZURE_SEARCH_KNOWLEDGE_BASE_NAME. - api_key: API key for authentication (optional, use managed identity if not provided). - Can be set via environment variable AZURE_SEARCH_API_KEY. - env_file_path: If provided, the .env settings are read from this file path location. - env_file_encoding: The encoding of the .env file, defaults to 'utf-8'. - - Examples: - .. 
code-block:: python - - from agent_framework_aisearch import AzureAISearchSettings - - # Using environment variables - # Set AZURE_SEARCH_ENDPOINT=https://mysearch.search.windows.net - # Set AZURE_SEARCH_INDEX_NAME=my-index - settings = AzureAISearchSettings() - - # Or passing parameters directly - settings = AzureAISearchSettings( - endpoint="https://mysearch.search.windows.net", - index_name="my-index", - ) - - # Or loading from a .env file - settings = AzureAISearchSettings(env_file_path="path/to/.env") - """ - - endpoint: str | None - index_name: str | None - knowledge_base_name: str | None - api_key: SecretString | None - - -class AzureAISearchContextProvider(ContextProvider): - """Azure AI Search Context Provider with hybrid search and semantic ranking. - - This provider retrieves relevant documents from Azure AI Search to provide context - to the AI agent. It supports two modes: - - - **agentic**: Recommended for most scenarios. Uses Knowledge Bases for query planning - and multi-hop reasoning. Slightly slower with more token consumption, but provides - more accurate results (up to 36% improvement in response relevance). - - **semantic** (default): Fast hybrid search combining vector and keyword search - with semantic reranking. Best for simple queries where speed is critical. - - Examples: - Using environment variables (recommended): - - .. code-block:: python - - from agent_framework_aisearch import AzureAISearchContextProvider - from azure.identity.aio import DefaultAzureCredential - - # Set AZURE_SEARCH_ENDPOINT and AZURE_SEARCH_INDEX_NAME in environment - search_provider = AzureAISearchContextProvider(credential=DefaultAzureCredential()) - - Semantic hybrid search with API key: - - .. code-block:: python - - # Direct API key string - search_provider = AzureAISearchContextProvider( - endpoint="https://mysearch.search.windows.net", - index_name="my-index", - api_key="my-api-key", - mode="semantic", - ) - - Loading from .env file: - - .. 
code-block:: python - - # Load settings from a .env file - search_provider = AzureAISearchContextProvider( - credential=DefaultAzureCredential(), env_file_path="path/to/.env" - ) - - Agentic retrieval for complex queries: - - .. code-block:: python - - # Use agentic mode for multi-hop reasoning - # Note: azure_openai_resource_url is the OpenAI endpoint for Knowledge Base model calls, - # which is different from azure_ai_project_endpoint (the AI Foundry project endpoint) - search_provider = AzureAISearchContextProvider( - endpoint="https://mysearch.search.windows.net", - index_name="my-index", - credential=DefaultAzureCredential(), - mode="agentic", - azure_openai_resource_url="https://myresource.openai.azure.com", - model_deployment_name="gpt-4o", - knowledge_base_name="my-knowledge-base", - ) - """ - - _DEFAULT_SEARCH_CONTEXT_PROMPT = "Use the following context to answer the question:" - - def __init__( - self, - endpoint: str | None = None, - index_name: str | None = None, - api_key: str | AzureKeyCredential | None = None, - credential: AsyncTokenCredential | None = None, - *, - mode: Literal["semantic", "agentic"] = "semantic", - top_k: int = 5, - semantic_configuration_name: str | None = None, - vector_field_name: str | None = None, - embedding_function: Callable[[str], Awaitable[list[float]]] | None = None, - context_prompt: str | None = None, - # Agentic mode parameters (Knowledge Base) - azure_openai_resource_url: str | None = None, - model_deployment_name: str | None = None, - model_name: str | None = None, - knowledge_base_name: str | None = None, - retrieval_instructions: str | None = None, - azure_openai_api_key: str | None = None, - knowledge_base_output_mode: Literal["extractive_data", "answer_synthesis"] = "extractive_data", - retrieval_reasoning_effort: Literal["minimal", "medium", "low"] = "minimal", - agentic_message_history_count: int = _DEFAULT_AGENTIC_MESSAGE_HISTORY_COUNT, - env_file_path: str | None = None, - env_file_encoding: str | None = 
None, - ) -> None: - """Initialize Azure AI Search Context Provider. - - Args: - endpoint: Azure AI Search endpoint URL. - Can also be set via environment variable AZURE_SEARCH_ENDPOINT. - index_name: Name of the search index to query. - Can also be set via environment variable AZURE_SEARCH_INDEX_NAME. - api_key: API key for authentication (string or AzureKeyCredential). - Can also be set via environment variable AZURE_SEARCH_API_KEY. - credential: AsyncTokenCredential for managed identity authentication. - Use this for Entra ID authentication instead of api_key. - mode: Search mode - "semantic" for hybrid search with semantic ranking (fast) - or "agentic" for multi-hop reasoning (slower). Default: "semantic". - top_k: Maximum number of documents to retrieve. Only applies to semantic mode. - In agentic mode, the server-side Knowledge Base determines retrieval based on - query complexity and reasoning effort. Default: 5. - semantic_configuration_name: Name of semantic configuration in the index. - Required for semantic ranking. If None, uses index default. - vector_field_name: Name of the vector field in the index for hybrid search. - Required if using vector search. Default: None (keyword search only). - embedding_function: Async function to generate embeddings for vector search. - Signature: async def embed(text: str) -> list[float] - Required if vector_field_name is specified and no server-side vectorization. - context_prompt: Custom prompt to prepend to retrieved context. - Default: "Use the following context to answer the question:" - azure_openai_resource_url: Azure OpenAI resource URL for Knowledge Base model calls. - Required when using agentic mode with index_name (to auto-create Knowledge Base). - Not required when using an existing knowledge_base_name. - Example: "https://myresource.openai.azure.com" - model_deployment_name: Model deployment name in Azure OpenAI for Knowledge Base. 
- Required when using agentic mode with index_name (to auto-create Knowledge Base). - Not required when using an existing knowledge_base_name. - model_name: The underlying model name (e.g., "gpt-4o", "gpt-4o-mini"). - If not provided, defaults to model_deployment_name. Used for Knowledge Base configuration. - knowledge_base_name: Name of an existing Knowledge Base to use. - Required for agentic mode if not providing index_name. - Supports KBs with any source type (web, blob, index, etc.). - retrieval_instructions: Custom instructions for the Knowledge Base's - retrieval planning. Only used in agentic mode. - azure_openai_api_key: Azure OpenAI API key for Knowledge Base to call the model. - Only needed when using API key authentication instead of managed identity. - knowledge_base_output_mode: Output mode for Knowledge Base retrieval. Only used in agentic mode. - "extractive_data": Returns raw chunks without synthesis (default, recommended for agent integration). - "answer_synthesis": Returns synthesized answer from the LLM. - Some knowledge sources require answer_synthesis mode. Default: "extractive_data". - retrieval_reasoning_effort: Reasoning effort for Knowledge Base query planning. Only used in agentic mode. - "minimal": Fastest, basic query planning. - "medium": Moderate reasoning with some query decomposition. - "low": Lower reasoning effort than medium. - Default: "minimal". - agentic_message_history_count: Number of recent messages from conversation history to send to - the Knowledge Base. This context helps with query planning in agentic mode, allowing the - Knowledge Base to understand the conversation flow and generate better retrieval queries. - There is no technical limit - adjust based on your use case. Default: 10. - env_file_path: Path to environment file for loading settings. - env_file_encoding: Encoding of the environment file. - - Examples: - .. 
code-block:: python - - from agent_framework_aisearch import AzureAISearchContextProvider - from azure.identity.aio import DefaultAzureCredential - - # Using environment variables - # Set AZURE_SEARCH_ENDPOINT=https://mysearch.search.windows.net - # Set AZURE_SEARCH_INDEX_NAME=my-index - credential = DefaultAzureCredential() - provider = AzureAISearchContextProvider(credential=credential) - - # Or passing parameters directly - provider = AzureAISearchContextProvider( - endpoint="https://mysearch.search.windows.net", - index_name="my-index", - credential=credential, - ) - - # Or loading from a .env file - provider = AzureAISearchContextProvider(credential=credential, env_file_path="path/to/.env") - """ - # Load settings from environment/file - settings = load_settings( - AzureAISearchSettings, - env_prefix="AZURE_SEARCH_", - endpoint=endpoint, - index_name=index_name, - knowledge_base_name=knowledge_base_name, - api_key=api_key if isinstance(api_key, str) else None, - env_file_path=env_file_path, - env_file_encoding=env_file_encoding, - ) - - # Validate required parameters - if not settings.get("endpoint"): - raise ServiceInitializationError( - "Azure AI Search endpoint is required. Set via 'endpoint' parameter " - "or 'AZURE_SEARCH_ENDPOINT' environment variable." - ) - - # Validate index_name and knowledge_base_name based on mode - # Note: settings["field"] / settings.get("field") contains the resolved value (explicit param OR env var) - if mode == "semantic": - # Semantic mode: always requires index_name - if not settings.get("index_name"): - raise ServiceInitializationError( - "Azure AI Search index name is required for semantic mode. " - "Set via 'index_name' parameter or 'AZURE_SEARCH_INDEX_NAME' environment variable." 
- ) - elif mode == "agentic": - # Agentic mode: requires exactly ONE of index_name or knowledge_base_name - if settings.get("index_name") and settings.get("knowledge_base_name"): - raise ServiceInitializationError( - "For agentic mode, provide either 'index_name' OR 'knowledge_base_name', not both. " - "Use 'index_name' to auto-create a Knowledge Base, or 'knowledge_base_name' to use an existing one." - ) - if not settings.get("index_name") and not settings.get("knowledge_base_name"): - raise ServiceInitializationError( - "For agentic mode, provide either 'index_name' (to auto-create Knowledge Base) " - "or 'knowledge_base_name' (to use existing Knowledge Base). " - "Set via parameters or environment variables " - "AZURE_SEARCH_INDEX_NAME / AZURE_SEARCH_KNOWLEDGE_BASE_NAME." - ) - # If using index_name to create KB, model config is required - if settings.get("index_name") and not model_deployment_name: - raise ServiceInitializationError( - "model_deployment_name is required for agentic mode when creating Knowledge Base from index. " - "This is the Azure OpenAI deployment used by the Knowledge Base for query planning." - ) - - # Determine the credential to use - resolved_credential: AzureKeyCredential | AsyncTokenCredential - if credential: - # AsyncTokenCredential takes precedence - resolved_credential = credential - elif isinstance(api_key, AzureKeyCredential): - resolved_credential = api_key - elif resolved_api_key := settings.get("api_key"): - resolved_credential = AzureKeyCredential(resolved_api_key.get_secret_value()) - else: - raise ServiceInitializationError( - "Azure credential is required. Provide 'api_key' or 'credential' parameter " - "or set 'AZURE_SEARCH_API_KEY' environment variable." 
- ) - - self.endpoint: str = settings["endpoint"] # type: ignore[assignment] # validated above - self.index_name = settings.get("index_name") - self.credential = resolved_credential - self.mode = mode - self.top_k = top_k - self.semantic_configuration_name = semantic_configuration_name - self.vector_field_name = vector_field_name - self.embedding_function = embedding_function - self.context_prompt = context_prompt or self._DEFAULT_SEARCH_CONTEXT_PROMPT - - # Agentic mode parameters (Knowledge Base) - self.azure_openai_resource_url = azure_openai_resource_url - self.azure_openai_deployment_name = model_deployment_name - # If model_name not provided, default to deployment name - self.model_name = model_name or model_deployment_name - # Use resolved KB name (from explicit param or env var) - self.knowledge_base_name = settings.get("knowledge_base_name") - self.retrieval_instructions = retrieval_instructions - self.azure_openai_api_key = azure_openai_api_key - self.knowledge_base_output_mode = knowledge_base_output_mode - self.retrieval_reasoning_effort = retrieval_reasoning_effort - self.agentic_message_history_count = agentic_message_history_count - - # Determine if using existing Knowledge Base or auto-creating from index - # Since validation ensures exactly one of index_name/knowledge_base_name for agentic mode: - # - knowledge_base_name provided: use existing KB - # - index_name provided: auto-create KB from index - self._use_existing_knowledge_base = False - if mode == "agentic": - if settings.get("knowledge_base_name"): - # Use existing KB directly (supports any source type: web, blob, index, etc.) 
- self._use_existing_knowledge_base = True - else: - # Auto-generate KB name from index name - self.knowledge_base_name = f"{settings.get('index_name', '')}-kb" - - # Auto-discover vector field if not specified - self._auto_discovered_vector_field = False - self._use_vectorizable_query = False # Will be set to True if server-side vectorization detected - if not vector_field_name and mode == "semantic": - # Attempt to auto-discover vector field from index schema - # This will be done lazily on first search to avoid blocking initialization - pass - - # Validation - if vector_field_name and not embedding_function: - raise ValueError("embedding_function is required when vector_field_name is specified") - - if mode == "agentic": - if not _agentic_retrieval_available: - raise ImportError( - "Agentic retrieval requires azure-search-documents >= 11.7.0b1 with Knowledge Base support. " - "Please upgrade: pip install azure-search-documents>=11.7.0b1" - ) - # Only require OpenAI resource URL if NOT using existing KB - # (existing KB already has its model configuration) - # Note: model_deployment_name is already validated at initialization - if not self._use_existing_knowledge_base and not self.azure_openai_resource_url: - raise ValueError( - "azure_openai_resource_url is required for agentic mode when creating Knowledge Base from index. 
" - "This should be your Azure OpenAI endpoint (e.g., 'https://myresource.openai.azure.com')" - ) - - # Create search client for semantic mode (only if index_name is available) - self._search_client: SearchClient | None = None - if self.index_name: - self._search_client = SearchClient( - endpoint=self.endpoint, - index_name=self.index_name, - credential=self.credential, - user_agent=AGENT_FRAMEWORK_USER_AGENT, - ) - - # Create index client and retrieval client for agentic mode (Knowledge Base) - self._index_client: SearchIndexClient | None = None - self._retrieval_client: KnowledgeBaseRetrievalClient | None = None - if mode == "agentic": - self._index_client = SearchIndexClient( - endpoint=self.endpoint, - credential=self.credential, - user_agent=AGENT_FRAMEWORK_USER_AGENT, - ) - # Retrieval client will be created after Knowledge Base initialization - - self._knowledge_base_initialized = False - - async def __aenter__(self) -> Self: - """Async context manager entry.""" - return self - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: Any, - ) -> None: - """Async context manager exit - cleanup clients. - - Args: - exc_type: Exception type if an error occurred. - exc_val: Exception value if an error occurred. - exc_tb: Exception traceback if an error occurred. - """ - # Close retrieval client if it was created - if self._retrieval_client is not None: - await self._retrieval_client.close() - self._retrieval_client = None - - @override - async def invoking( - self, - messages: Message | MutableSequence[Message], - **kwargs: Any, - ) -> Context: - """Retrieve relevant context from Azure AI Search before model invocation. - - Args: - messages: User messages to use for context retrieval. - **kwargs: Additional arguments (unused). - - Returns: - Context object with retrieved documents as messages. 
- """ - # Convert to list and filter to USER/ASSISTANT messages with text only - messages_list = [messages] if isinstance(messages, Message) else list(messages) - - def get_role_value(role: str | Any) -> str: - return role.value if hasattr(role, "value") else str(role) - - filtered_messages = [ - msg - for msg in messages_list - if msg and msg.text and msg.text.strip() and get_role_value(msg.role) in ["user", "assistant"] - ] - - if not filtered_messages: - return Context() - - # Perform search based on mode - if self.mode == "semantic": - # Semantic mode: flatten messages to single query - query = "\n".join(msg.text for msg in filtered_messages) - search_result_parts = await self._semantic_search(query) - else: # agentic - # Agentic mode: pass recent messages as conversation history - recent_messages = filtered_messages[-self.agentic_message_history_count :] - search_result_parts = await self._agentic_search(recent_messages) - - # Format results as context - return multiple messages for each result part - if not search_result_parts: - return Context() - - # Create context messages: first message with prompt, then one message per result part - context_messages = [Message(role="user", text=self.context_prompt)] - context_messages.extend([Message(role="user", text=part) for part in search_result_parts]) - - return Context(messages=context_messages) - - def _find_vector_fields(self, index: Any) -> list[str]: - """Find all fields that can store vectors (have dimensions defined). - - Args: - index: SearchIndex object from Azure Search. - - Returns: - List of vector field names. - """ - return [ - field.name - for field in index.fields - if field.vector_search_dimensions is not None and field.vector_search_dimensions > 0 - ] - - def _find_vectorizable_fields(self, index: Any, vector_fields: list[str]) -> list[str]: - """Find vector fields that have auto-vectorization configured. 
- - These are fields that have a vectorizer in their profile, meaning the index - can automatically vectorize text queries without needing a client-side embedding function. - - Args: - index: SearchIndex object from Azure Search. - vector_fields: List of vector field names. - - Returns: - List of vectorizable field names (subset of vector_fields). - """ - vectorizable_fields: list[str] = [] - - # Check if index has vector search configuration - if not index.vector_search or not index.vector_search.profiles: - return vectorizable_fields - - # For each vector field, check if it has a vectorizer configured - for field in index.fields: - if field.name in vector_fields and field.vector_search_profile_name: - # Find the profile for this field - profile = next( - (p for p in index.vector_search.profiles if p.name == field.vector_search_profile_name), None - ) - - if profile and hasattr(profile, "vectorizer_name") and profile.vectorizer_name: - # This field has server-side vectorization configured - vectorizable_fields.append(field.name) - - return vectorizable_fields - - async def _auto_discover_vector_field(self) -> None: - """Auto-discover vector field from index schema. - - Attempts to find vector fields in the index and detect which have server-side - vectorization configured. Prioritizes vectorizable fields (which can auto-embed text) - over regular vector fields (which require client-side embedding). 
- """ - if self._auto_discovered_vector_field or self.vector_field_name: - return # Already discovered or manually specified - - try: - # Use existing index client or create temporary one - if not self._index_client: - self._index_client = SearchIndexClient( - endpoint=self.endpoint, - credential=self.credential, - user_agent=AGENT_FRAMEWORK_USER_AGENT, - ) - index_client = self._index_client - - # Get index schema (index_name is guaranteed to be set for semantic mode) - if not self.index_name: - logger.warning("Cannot auto-discover vector field: index_name is not set.") - self._auto_discovered_vector_field = True - return - - index = await index_client.get_index(self.index_name) - - # Step 1: Find all vector fields - vector_fields = self._find_vector_fields(index) - - if not vector_fields: - # No vector fields found - keyword search only - logger.info(f"No vector fields found in index '{self.index_name}'. Using keyword-only search.") - self._auto_discovered_vector_field = True - return - - # Step 2: Find which vector fields have server-side vectorization - vectorizable_fields = self._find_vectorizable_fields(index, vector_fields) - - # Step 3: Decide which field to use - if vectorizable_fields: - # Prefer vectorizable fields (server-side embedding) - if len(vectorizable_fields) == 1: - self.vector_field_name = vectorizable_fields[0] - self._auto_discovered_vector_field = True - self._use_vectorizable_query = True # Use VectorizableTextQuery - logger.info( - f"Auto-discovered vectorizable field '{self.vector_field_name}' " - f"with server-side vectorization. No embedding_function needed." - ) - else: - # Multiple vectorizable fields - logger.warning( - f"Multiple vectorizable fields found: {vectorizable_fields}. " - f"Please specify vector_field_name explicitly. Using keyword-only search." 
- ) - elif len(vector_fields) == 1: - # Single vector field without vectorizer - needs client-side embedding - self.vector_field_name = vector_fields[0] - self._auto_discovered_vector_field = True - self._use_vectorizable_query = False - - if not self.embedding_function: - logger.warning( - f"Auto-discovered vector field '{self.vector_field_name}' without server-side vectorization. " - f"Provide embedding_function for vector search, or it will fall back to keyword-only search." - ) - self.vector_field_name = None - else: - # Multiple vector fields without vectorizers - logger.warning( - f"Multiple vector fields found: {vector_fields}. " - f"Please specify vector_field_name explicitly. Using keyword-only search." - ) - - except Exception as e: - # Log warning but continue with keyword search - logger.warning(f"Failed to auto-discover vector field: {e}. Using keyword-only search.") - - self._auto_discovered_vector_field = True # Mark as attempted - - async def _semantic_search(self, query: str) -> list[str]: - """Perform semantic hybrid search with semantic ranking. - - This is the recommended mode for most use cases. It combines: - - Vector search (if embedding_function provided) - - Keyword search (BM25) - - Semantic reranking (if semantic_configuration_name provided) - - Args: - query: Search query text. - - Returns: - List of formatted search result strings, one per document. 
- """ - # Auto-discover vector field if not already done - await self._auto_discover_vector_field() - - vector_queries: list[VectorizableTextQuery | VectorizedQuery] = [] - - # Build vector query based on server-side vectorization or client-side embedding - if self.vector_field_name: - # Use larger k for vector query when semantic reranker is enabled for better ranking quality - vector_k = max(self.top_k, 50) if self.semantic_configuration_name else self.top_k - - if self._use_vectorizable_query: - # Server-side vectorization: Index will auto-embed the text query - vector_queries = [ - VectorizableTextQuery( - text=query, - k_nearest_neighbors=vector_k, - fields=self.vector_field_name, - ) - ] - elif self.embedding_function: - # Client-side embedding: We provide the vector - query_vector = await self.embedding_function(query) - vector_queries = [ - VectorizedQuery( - vector=query_vector, - k_nearest_neighbors=vector_k, - fields=self.vector_field_name, - ) - ] - # else: vector_field_name is set but no vectorization available - skip vector search - - # Build search parameters - search_params: dict[str, Any] = { - "search_text": query, - "top": self.top_k, - } - - if vector_queries: - search_params["vector_queries"] = vector_queries - - # Add semantic ranking if configured - if self.semantic_configuration_name: - search_params["query_type"] = QueryType.SEMANTIC - search_params["semantic_configuration_name"] = self.semantic_configuration_name - search_params["query_caption"] = QueryCaptionType.EXTRACTIVE - - # Execute search (search client is guaranteed to exist for semantic mode) - if not self._search_client: - raise RuntimeError("Search client is not initialized. 
This should not happen in semantic mode.") - - results = await self._search_client.search(**search_params) # type: ignore[reportUnknownVariableType] - - # Format results with citations - formatted_results: list[str] = [] - async for doc in results: # type: ignore[reportUnknownVariableType] - # Extract document ID for citation - doc_id = doc.get("id") or doc.get("@search.id") # type: ignore[reportUnknownVariableType] - - # Use full document chunks with citation - doc_text: str = self._extract_document_text(doc, doc_id=doc_id) # type: ignore[reportUnknownArgumentType] - if doc_text: - formatted_results.append(doc_text) # type: ignore[reportUnknownArgumentType] - - return formatted_results - - async def _ensure_knowledge_base(self) -> None: - """Ensure Knowledge Base and knowledge source are created or use existing KB. - - This method is idempotent - it will only create resources if they don't exist. - - Note: Azure SDK uses KnowledgeAgent classes internally, but the feature - is marketed as "Knowledge Bases" in Azure AI Search. - """ - if self._knowledge_base_initialized: - return - - # Runtime validation - if not self.knowledge_base_name: - raise ValueError("knowledge_base_name is required for agentic mode") - - knowledge_base_name = self.knowledge_base_name - - # Path 1: Use existing Knowledge Base directly (no index needed) - # This supports KB with any source type (web, blob, index, etc.) 
- if self._use_existing_knowledge_base: - # Just create the retrieval client - KB already exists with its own sources - if _agentic_retrieval_available and self._retrieval_client is None: - self._retrieval_client = KnowledgeBaseRetrievalClient( - endpoint=self.endpoint, - knowledge_base_name=knowledge_base_name, - credential=self.credential, - user_agent=AGENT_FRAMEWORK_USER_AGENT, - ) - self._knowledge_base_initialized = True - return - - # Path 2: Auto-create Knowledge Base from search index - # Requires index_client and OpenAI configuration - if not self._index_client: - raise ValueError("Index client is required when creating Knowledge Base from index") - if not self.azure_openai_resource_url: - raise ValueError("azure_openai_resource_url is required when creating Knowledge Base from index") - if not self.azure_openai_deployment_name: - raise ValueError("model_deployment_name is required when creating Knowledge Base from index") - if not self.index_name: - raise ValueError("index_name is required when creating Knowledge Base from index") - - # Step 1: Create or get knowledge source from index - knowledge_source_name = f"{self.index_name}-source" - - try: - # Try to get existing knowledge source - await self._index_client.get_knowledge_source(knowledge_source_name) - except ResourceNotFoundError: - # Create new knowledge source if it doesn't exist - knowledge_source = SearchIndexKnowledgeSource( - name=knowledge_source_name, - description=f"Knowledge source for {self.index_name} search index", - search_index_parameters=SearchIndexKnowledgeSourceParameters( - search_index_name=self.index_name, - ), - ) - await self._index_client.create_knowledge_source(knowledge_source) - - # Step 2: Create or update Knowledge Base - # Always create/update to ensure configuration is current - aoai_params = AzureOpenAIVectorizerParameters( - resource_url=self.azure_openai_resource_url, - deployment_name=self.azure_openai_deployment_name, - model_name=self.model_name, - 
api_key=self.azure_openai_api_key, - ) - - # Map output mode string to SDK enum - output_mode = ( - KnowledgeRetrievalOutputMode.EXTRACTIVE_DATA - if self.knowledge_base_output_mode == "extractive_data" - else KnowledgeRetrievalOutputMode.ANSWER_SYNTHESIS - ) - - # Map reasoning effort string to SDK class - reasoning_effort_map: dict[str, KnowledgeRetrievalReasoningEffort] = { - "minimal": KnowledgeRetrievalMinimalReasoningEffort(), - "medium": KnowledgeRetrievalMediumReasoningEffort(), - "low": KnowledgeRetrievalLowReasoningEffort(), - } - reasoning_effort = reasoning_effort_map[self.retrieval_reasoning_effort] - - knowledge_base = KnowledgeBase( - name=knowledge_base_name, - description=f"Knowledge Base for multi-hop retrieval across {self.index_name}", - knowledge_sources=[ - KnowledgeSourceReference( - name=knowledge_source_name, - ) - ], - models=[KnowledgeBaseAzureOpenAIModel(azure_open_ai_parameters=aoai_params)], - output_mode=output_mode, - retrieval_reasoning_effort=reasoning_effort, - ) - await self._index_client.create_or_update_knowledge_base(knowledge_base) - - self._knowledge_base_initialized = True - - # Create retrieval client now that Knowledge Base is initialized - if _agentic_retrieval_available and self._retrieval_client is None: - self._retrieval_client = KnowledgeBaseRetrievalClient( - endpoint=self.endpoint, - knowledge_base_name=knowledge_base_name, - credential=self.credential, - user_agent=AGENT_FRAMEWORK_USER_AGENT, - ) - - async def _agentic_search(self, messages: list[Message]) -> list[str]: - """Perform agentic retrieval with multi-hop reasoning using Knowledge Bases. - - This mode uses query planning and is slightly slower than semantic search, - but provides more accurate results through intelligent retrieval. - - This method uses Azure AI Search Knowledge Bases which: - 1. Analyze the query and plan sub-queries - 2. Retrieve relevant documents across multiple sources - 3. Perform multi-hop reasoning with an LLM - 4. 
Synthesize a comprehensive answer with references - - Args: - messages: Conversation history to use for retrieval context. - - Returns: - List of answer parts from the Knowledge Base, one per content item. - """ - # Ensure Knowledge Base is initialized - await self._ensure_knowledge_base() - - # Map reasoning effort string to SDK class (for retrieval requests) - reasoning_effort_map: dict[str, KBRetrievalReasoningEffort] = { - "minimal": KBRetrievalMinimalReasoningEffort(), - "medium": KBRetrievalMediumReasoningEffort(), - "low": KBRetrievalLowReasoningEffort(), - } - reasoning_effort = reasoning_effort_map[self.retrieval_reasoning_effort] - - # Map output mode string to SDK enum (for retrieval requests) - output_mode = ( - KBRetrievalOutputMode.EXTRACTIVE_DATA - if self.knowledge_base_output_mode == "extractive_data" - else KBRetrievalOutputMode.ANSWER_SYNTHESIS - ) - - # For minimal reasoning, use intents API; for medium/low, use messages API - if self.retrieval_reasoning_effort == "minimal": - # Minimal reasoning uses intents with a single search query - query = "\n".join(msg.text for msg in messages if msg.text) - intents: list[KnowledgeRetrievalIntent] = [KnowledgeRetrievalSemanticIntent(search=query)] - retrieval_request = KnowledgeBaseRetrievalRequest( - intents=intents, - retrieval_reasoning_effort=reasoning_effort, - output_mode=output_mode, - include_activity=True, - ) - else: - # Medium/low reasoning uses messages with conversation history - kb_messages = [ - KnowledgeBaseMessage( - role=msg.role if hasattr(msg.role, "value") else str(msg.role), - content=[KnowledgeBaseMessageTextContent(text=msg.text)], - ) - for msg in messages - if msg.text - ] - retrieval_request = KnowledgeBaseRetrievalRequest( - messages=kb_messages, - retrieval_reasoning_effort=reasoning_effort, - output_mode=output_mode, - include_activity=True, - ) - - # Use reusable retrieval client - if not self._retrieval_client: - raise RuntimeError("Retrieval client not initialized. 
Ensure Knowledge Base is set up correctly.") - - # Perform retrieval via Knowledge Base - retrieval_result = await self._retrieval_client.retrieve(retrieval_request=retrieval_request) - - # Extract answer parts from response - if retrieval_result.response and len(retrieval_result.response) > 0: - # Get the assistant's response (last message) - assistant_message = retrieval_result.response[-1] - if assistant_message.content: - # Extract all text content items as separate parts - answer_parts: list[str] = [] - for content_item in assistant_message.content: - # Check if this is a text content item - if isinstance(content_item, KnowledgeBaseMessageTextContent) and content_item.text: - answer_parts.append(content_item.text) - - if answer_parts: - return answer_parts - - # Fallback if no answer generated - return ["No results found from Knowledge Base."] - - def _extract_document_text(self, doc: dict[str, Any], doc_id: str | None = None) -> str: - """Extract readable text from a search document with optional citation. - - Args: - doc: Search result document. - doc_id: Optional document ID for citation. - - Returns: - Formatted document text with citation if doc_id provided. 
- """ - # Try common text field names - text = "" - for field in ["content", "text", "description", "body", "chunk"]: - if doc.get(field): - text = str(doc[field]) - break - - # Fallback: concatenate all string fields - if not text: - text_parts: list[str] = [] - for key, value in doc.items(): - if isinstance(value, str) and not key.startswith("@") and key != "id": - text_parts.append(f"{key}: {value}") - text = " | ".join(text_parts) if text_parts else "" - - # Add citation if document ID provided - if doc_id and text: - return f"[Source: {doc_id}] {text}" - return text diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index 041aa17306..48095326de 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -13,11 +13,9 @@ from ._clients import * # noqa: F403 from ._logging import * # noqa: F403 from ._mcp import * # noqa: F403 -from ._memory import * # noqa: F403 from ._middleware import * # noqa: F403 from ._sessions import * # noqa: F403 from ._telemetry import * # noqa: F403 -from ._threads import * # noqa: F403 from ._tools import * # noqa: F403 from ._types import * # noqa: F403 from ._workflows import * # noqa: F403 diff --git a/python/packages/core/agent_framework/_memory.py b/python/packages/core/agent_framework/_memory.py deleted file mode 100644 index f6c2bd6403..0000000000 --- a/python/packages/core/agent_framework/_memory.py +++ /dev/null @@ -1,181 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. 
- -from __future__ import annotations - -import sys -from abc import ABC, abstractmethod -from collections.abc import MutableSequence, Sequence -from types import TracebackType -from typing import TYPE_CHECKING, Any, Final - -from ._types import Message - -if TYPE_CHECKING: - from ._tools import FunctionTool - -if sys.version_info >= (3, 11): - from typing import Self # pragma: no cover -else: - from typing_extensions import Self # pragma: no cover - -# region Context - -__all__ = ["Context", "ContextProvider"] - - -class Context: - """A class containing any context that should be provided to the AI model as supplied by a ContextProvider. - - Each ContextProvider has the ability to provide its own context for each invocation. - The Context class contains the additional context supplied by the ContextProvider. - This context will be combined with context supplied by other providers before being passed to the AI model. - This context is per invocation, and will not be stored as part of the chat history. - - Examples: - .. code-block:: python - - from agent_framework import Context, Message - - # Create context with instructions - context = Context( - instructions="Use a professional tone when responding.", - messages=[Message(content="Previous context", role="user")], - tools=[my_tool], - ) - - # Access context properties - print(context.instructions) - print(len(context.messages)) - """ - - def __init__( - self, - instructions: str | None = None, - messages: Sequence[Message] | None = None, - tools: Sequence[FunctionTool] | None = None, - ): - """Create a new Context object. - - Args: - instructions: The instructions to provide to the AI model. - messages: The list of messages to include in the context. - tools: The list of tools to provide to this run. 
- """ - self.instructions = instructions - self.messages: Sequence[Message] = messages or [] - self.tools: Sequence[FunctionTool] = tools or [] - - -# region ContextProvider - - -class ContextProvider(ABC): - """Base class for all context providers. - - A context provider is a component that can be used to enhance the AI's context management. - It can listen to changes in the conversation and provide additional context to the AI model - just before invocation. - - Note: - ContextProvider is an abstract base class. You must subclass it and implement - the ``invoking()`` method to create a custom context provider. Ideally, you should - also implement the ``invoked()`` and ``thread_created()`` methods to track conversation - state, but these are optional. - - Examples: - .. code-block:: python - - from agent_framework import ContextProvider, Context, Message - - - class CustomContextProvider(ContextProvider): - async def invoking(self, messages, **kwargs): - # Add custom instructions before each invocation - return Context(instructions="Always be concise and helpful.", messages=[], tools=[]) - - - # Use with a chat agent - async with CustomContextProvider() as provider: - agent = Agent(client=client, name="assistant", context_provider=provider) - """ - - # Default prompt to be used by all context providers when assembling memories/instructions - DEFAULT_CONTEXT_PROMPT: Final[str] = "## Memories\nConsider the following memories when answering user questions:" - - async def thread_created(self, thread_id: str | None) -> None: - """Called just after a new thread is created. - - Implementers can use this method to perform any operations required at the creation - of a new thread. For example, checking long-term storage for any data that is relevant - to the current session. - - Args: - thread_id: The ID of the new thread. 
- """ - pass - - async def invoked( - self, - request_messages: Message | Sequence[Message], - response_messages: Message | Sequence[Message] | None = None, - invoke_exception: Exception | None = None, - **kwargs: Any, - ) -> None: - """Called after the agent has received a response from the underlying inference service. - - You can inspect the request and response messages, and update the state of the context provider. - - Args: - request_messages: The messages that were sent to the model/agent. - response_messages: The messages that were returned by the model/agent. - invoke_exception: The exception that was thrown, if any. - - Keyword Args: - kwargs: Additional keyword arguments (not used at present). - """ - pass - - @abstractmethod - async def invoking(self, messages: Message | MutableSequence[Message], **kwargs: Any) -> Context: - """Called just before the model/agent is invoked. - - Implementers can load any additional context required at this time, - and they should return any context that should be passed to the agent. - - Args: - messages: The most recent messages that the agent is being invoked with. - - Keyword Args: - kwargs: Additional keyword arguments (not used at present). - - Returns: - A Context object containing instructions, messages, and tools to include. - """ - pass - - async def __aenter__(self) -> Self: - """Enter the async context manager. - - Override this method to perform any setup operations when the context provider is entered. - - Returns: - The ContextProvider instance for chaining. - """ - return self - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - """Exit the async context manager. - - Override this method to perform any cleanup operations when the context provider is exited. - - Args: - exc_type: The exception type if an exception occurred, None otherwise. 
- exc_val: The exception value if an exception occurred, None otherwise. - exc_tb: The exception traceback if an exception occurred, None otherwise. - """ - pass diff --git a/python/packages/core/agent_framework/_serialization.py b/python/packages/core/agent_framework/_serialization.py index 8588e0be5a..1259ca2a1a 100644 --- a/python/packages/core/agent_framework/_serialization.py +++ b/python/packages/core/agent_framework/_serialization.py @@ -166,45 +166,22 @@ class SerializationMixin: during deserialization via the ``dependencies`` parameter. Examples: - **Nested object serialization with agent thread management:** + **Nested object serialization:** .. code-block:: python from agent_framework import Message - from agent_framework._threads import AgentThreadState, ChatMessageStoreState + from agent_framework._sessions import AgentSession - # ChatMessageStoreState handles nested Message serialization - store_state = ChatMessageStoreState( - messages=[ - Message(role="user", text="Hello agent"), - Message(role="assistant", text="Hi! 
How can I help?"), - ] - ) - - # Nested serialization: messages are automatically converted to dicts - store_dict = store_state.to_dict() - # Result: { - # "type": "chat_message_store_state", - # "messages": [ - # {"type": "chat_message", "role": {...}, "contents": [...]}, - # {"type": "chat_message", "role": {...}, "contents": [...]} - # ] - # } - - # AgentThreadState contains nested ChatMessageStoreState - thread_state = AgentThreadState(chat_message_store_state=store_state) + # AgentSession uses SerializationMixin for state serialization + session = AgentSession(session_id="test") - # Deep serialization: nested SerializationMixin objects are handled automatically - thread_dict = thread_state.to_dict() - # The chat_message_store_state and its nested messages are all serialized + # Serialization produces a clean dict representation + session_dict = session.to_dict() - # Reconstruction from nested dictionaries with automatic type conversion - # The __init__ method handles MutableMapping -> object conversion: - reconstructed = AgentThreadState.from_dict({ - "chat_message_store_state": {"messages": [{"role": "user", "text": "Hello again"}]} - }) - # chat_message_store_state becomes ChatMessageStoreState instance automatically + # Reconstruction from dictionaries + restored = AgentSession.from_dict(session_dict) **Framework tools with exclusion patterns:** diff --git a/python/packages/core/agent_framework/_threads.py b/python/packages/core/agent_framework/_threads.py deleted file mode 100644 index 83b33519d8..0000000000 --- a/python/packages/core/agent_framework/_threads.py +++ /dev/null @@ -1,507 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. 
- -from __future__ import annotations - -from collections.abc import MutableMapping, Sequence -from typing import Any, Protocol, TypeVar - -from ._memory import ContextProvider -from ._serialization import SerializationMixin -from ._types import Message -from .exceptions import AgentThreadException - -__all__ = ["AgentThread", "ChatMessageStore", "ChatMessageStoreProtocol"] - - -class ChatMessageStoreProtocol(Protocol): - """Defines methods for storing and retrieving chat messages associated with a specific thread. - - Implementations of this protocol are responsible for managing the storage of chat messages, - including handling large volumes of data by truncating or summarizing messages as necessary. - - Examples: - .. code-block:: python - - from agent_framework import Message - - - class MyMessageStore: - def __init__(self): - self._messages = [] - - async def list_messages(self) -> list[Message]: - return self._messages - - async def add_messages(self, messages: Sequence[Message]) -> None: - self._messages.extend(messages) - - @classmethod - async def deserialize(cls, serialized_store_state, **kwargs): - store = cls() - store._messages = serialized_store_state.get("messages", []) - return store - - async def update_from_state(self, serialized_store_state, **kwargs) -> None: - self._messages = serialized_store_state.get("messages", []) - - async def serialize(self, **kwargs): - return {"messages": self._messages} - - - # Use the custom store - store = MyMessageStore() - """ - - async def list_messages(self) -> list[Message]: - """Gets all the messages from the store that should be used for the next agent invocation. - - Messages are returned in ascending chronological order, with the oldest message first. - - If the messages stored in the store become very large, it is up to the store to - truncate, summarize or otherwise limit the number of messages returned. 
- - When using implementations of ``ChatMessageStoreProtocol``, a new one should be created for each thread - since they may contain state that is specific to a thread. - """ - ... - - async def add_messages(self, messages: Sequence[Message]) -> None: - """Adds messages to the store. - - Args: - messages: The sequence of Message objects to add to the store. - """ - ... - - @classmethod - async def deserialize( - cls, serialized_store_state: MutableMapping[str, Any], **kwargs: Any - ) -> ChatMessageStoreProtocol: - """Creates a new instance of the store from previously serialized state. - - This method, together with ``serialize()`` can be used to save and load messages from a persistent store - if this store only has messages in memory. - - Args: - serialized_store_state: The previously serialized state data containing messages. - - Keyword Args: - **kwargs: Additional arguments for deserialization. - - Returns: - A new instance of the store populated with messages from the serialized state. - """ - ... - - async def update_from_state(self, serialized_store_state: MutableMapping[str, Any], **kwargs: Any) -> None: - """Update the current ChatMessageStore instance from serialized state data. - - Args: - serialized_store_state: Previously serialized state data containing messages. - - Keyword Args: - kwargs: Additional arguments for deserialization. - """ - ... - - async def serialize(self, **kwargs: Any) -> dict[str, Any]: - """Serializes the current object's state. - - This method, together with ``deserialize()`` can be used to save and load messages from a persistent store - if this store only has messages in memory. - - Keyword Args: - kwargs: Additional arguments for serialization. - - Returns: - The serialized state data that can be used with ``deserialize()``. - """ - ... - - -class ChatMessageStoreState(SerializationMixin): - """State model for serializing and deserializing chat message store data. 
- - Attributes: - messages: List of chat messages stored in the message store. - """ - - def __init__( - self, - messages: Sequence[Message] | Sequence[MutableMapping[str, Any]] | None = None, - **kwargs: Any, - ) -> None: - """Create the store state. - - Args: - messages: a list of messages or a list of the dict representation of messages. - - Keyword Args: - **kwargs: not used for this, but might be used by subclasses. - - """ - if not messages: - self.messages: list[Message] = [] - return - if not isinstance(messages, list): - raise TypeError("Messages should be a list") - new_messages: list[Message] = [] - for msg in messages: - if isinstance(msg, Message): - new_messages.append(msg) - else: - new_messages.append(Message.from_dict(msg)) - self.messages = new_messages - - -class AgentThreadState(SerializationMixin): - """State model for serializing and deserializing thread information.""" - - def __init__( - self, - *, - service_thread_id: str | None = None, - chat_message_store_state: ChatMessageStoreState | MutableMapping[str, Any] | None = None, - ) -> None: - """Create a AgentThread state. - - Keyword Args: - service_thread_id: Optional ID of the thread managed by the agent service. - chat_message_store_state: Optional serialized state of the chat message store. 
- """ - if service_thread_id is not None and chat_message_store_state is not None: - raise AgentThreadException("A thread cannot have both a service_thread_id and a chat_message_store.") - self.service_thread_id = service_thread_id - self.chat_message_store_state: ChatMessageStoreState | None = None - if chat_message_store_state is not None: - if isinstance(chat_message_store_state, dict): - self.chat_message_store_state = ChatMessageStoreState.from_dict(chat_message_store_state) - elif isinstance(chat_message_store_state, ChatMessageStoreState): - self.chat_message_store_state = chat_message_store_state - else: - raise TypeError("Could not parse ChatMessageStoreState.") - - -ChatMessageStoreT = TypeVar("ChatMessageStoreT", bound="ChatMessageStore") - - -class ChatMessageStore: - """An in-memory implementation of ChatMessageStoreProtocol that stores messages in a list. - - This implementation provides a simple, list-based storage for chat messages - with support for serialization and deserialization. It implements all the - required methods of the ``ChatMessageStoreProtocol`` protocol. - - The store maintains messages in memory and provides methods to serialize - and deserialize the state for persistence purposes. - - Examples: - .. code-block:: python - - from agent_framework import ChatMessageStore, Message - - # Create an empty store - store = ChatMessageStore() - - # Add messages - message = Message(role="user", text="Hello") - await store.add_messages([message]) - - # Retrieve messages - messages = await store.list_messages() - - # Serialize for persistence - state = await store.serialize() - - # Deserialize from saved state - restored_store = await ChatMessageStore.deserialize(state) - """ - - def __init__(self, messages: Sequence[Message] | None = None): - """Create a ChatMessageStore for use in a thread. - - Args: - messages: The messages to store. 
- """ - self.messages = list(messages) if messages else [] - - async def add_messages(self, messages: Sequence[Message]) -> None: - """Add messages to the store. - - Args: - messages: Sequence of Message objects to add to the store. - """ - self.messages.extend(messages) - - async def list_messages(self) -> list[Message]: - """Get all messages from the store in chronological order. - - Returns: - List of Message objects, ordered from oldest to newest. - """ - return self.messages - - @classmethod - async def deserialize( - cls: type[ChatMessageStoreT], serialized_store_state: MutableMapping[str, Any], **kwargs: Any - ) -> ChatMessageStoreT: - """Create a new ChatMessageStore instance from serialized state data. - - Args: - serialized_store_state: Previously serialized state data containing messages. - - Keyword Args: - **kwargs: Additional arguments for deserialization. - - Returns: - A new ChatMessageStore instance populated with messages from the serialized state. - """ - state = ChatMessageStoreState.from_dict(serialized_store_state, **kwargs) - if state.messages: - return cls(messages=state.messages) - return cls() - - async def update_from_state(self, serialized_store_state: MutableMapping[str, Any], **kwargs: Any) -> None: - """Update the current ChatMessageStore instance from serialized state data. - - Args: - serialized_store_state: Previously serialized state data containing messages. - - Keyword Args: - **kwargs: Additional arguments for deserialization. - """ - if not serialized_store_state: - return - state = ChatMessageStoreState.from_dict(serialized_store_state, **kwargs) - if state.messages: - self.messages = state.messages - - async def serialize(self, **kwargs: Any) -> dict[str, Any]: - """Serialize the current store state for persistence. - - Keyword Args: - **kwargs: Additional arguments for serialization. - - Returns: - Serialized state data that can be used with deserialize_state. 
- """ - state = ChatMessageStoreState(messages=self.messages) - return state.to_dict() - - -AgentThreadT = TypeVar("AgentThreadT", bound="AgentThread") - - -class AgentThread: - """The Agent thread class, this can represent both a locally managed thread or a thread managed by the service. - - An ``AgentThread`` maintains the conversation state and message history for an agent interaction. - It can either use a service-managed thread (via ``service_thread_id``) or a local message store - (via ``message_store``), but not both. - - Examples: - .. code-block:: python - - from agent_framework import Agent, ChatMessageStore - from agent_framework.openai import OpenAIChatClient - - client = OpenAIChatClient(model="gpt-4o") - - # Create agent with service-managed threads using a service_thread_id - service_agent = Agent(name="assistant", client=client) - service_thread = await service_agent.get_new_thread(service_thread_id="thread_abc123") - - # Create agent with service-managed threads using conversation_id - conversation_agent = Agent(name="assistant", client=client, conversation_id="thread_abc123") - conversation_thread = await conversation_agent.get_new_thread() - - # Create agent with custom message store factory - local_agent = Agent(name="assistant", client=client, chat_message_store_factory=ChatMessageStore) - local_thread = await local_agent.get_new_thread() - - # Serialize and restore thread state - state = await local_thread.serialize() - restored_thread = await local_agent.deserialize_thread(state) - """ - - def __init__( - self, - *, - service_thread_id: str | None = None, - message_store: ChatMessageStoreProtocol | None = None, - context_provider: ContextProvider | None = None, - ) -> None: - """Initialize an AgentThread, do not use this method manually, always use: ``agent.get_new_thread()``. - - Args: - service_thread_id: The optional ID of the thread managed by the agent service. 
- message_store: The optional ChatMessageStore implementation for managing chat messages. - context_provider: The optional ContextProvider for the thread. - - Note: - Either ``service_thread_id`` or ``message_store`` may be set, but not both. - """ - if service_thread_id is not None and message_store is not None: - raise AgentThreadException("Only the service_thread_id or message_store may be set, but not both.") - - self._service_thread_id = service_thread_id - self._message_store = message_store - self.context_provider = context_provider - - @property - def is_initialized(self) -> bool: - """Indicates if the thread is initialized. - - This means either the ``service_thread_id`` or the ``message_store`` is set. - """ - return self._service_thread_id is not None or self._message_store is not None - - @property - def service_thread_id(self) -> str | None: - """Gets the ID of the current thread to support cases where the thread is owned by the agent service.""" - return self._service_thread_id - - @service_thread_id.setter - def service_thread_id(self, service_thread_id: str | None) -> None: - """Sets the ID of the current thread to support cases where the thread is owned by the agent service. - - Note: - Either ``service_thread_id`` or ``message_store`` may be set, but not both. - """ - if service_thread_id is None: - return - - if self._message_store is not None: - raise AgentThreadException( - "Only the service_thread_id or message_store may be set, " - "but not both and switching from one to another is not supported." - ) - self._service_thread_id = service_thread_id - - @property - def message_store(self) -> ChatMessageStoreProtocol | None: - """Gets the ``ChatMessageStoreProtocol`` used by this thread.""" - return self._message_store - - @message_store.setter - def message_store(self, message_store: ChatMessageStoreProtocol | None) -> None: - """Sets the ``ChatMessageStoreProtocol`` used by this thread. 
- - Note: - Either ``service_thread_id`` or ``message_store`` may be set, but not both. - """ - if message_store is None: - return - - if self._service_thread_id is not None: - raise AgentThreadException( - "Only the service_thread_id or message_store may be set, " - "but not both and switching from one to another is not supported." - ) - - self._message_store = message_store - - async def on_new_messages(self, new_messages: Message | Sequence[Message]) -> None: - """Invoked when a new message has been contributed to the chat by any participant. - - Args: - new_messages: The new Message or sequence of Message objects to add to the thread. - """ - if self._service_thread_id is not None: - # If the thread messages are stored in the service there is nothing to do here, - # since invoking the service should already update the thread. - return - if self._message_store is None: - # If there is no conversation id, and no store we can - # create a default in memory store. - self._message_store = ChatMessageStore() - # If a store has been provided, we need to add the messages to the store. - if isinstance(new_messages, Message): - new_messages = [new_messages] - await self._message_store.add_messages(new_messages) - - async def serialize(self, **kwargs: Any) -> dict[str, Any]: - """Serializes the current object's state. - - Keyword Args: - **kwargs: Arguments for serialization. 
- """ - chat_message_store_state = None - if self._message_store is not None: - chat_message_store_state = await self._message_store.serialize(**kwargs) - - state = AgentThreadState( - service_thread_id=self._service_thread_id, chat_message_store_state=chat_message_store_state - ) - return state.to_dict(exclude_none=False) - - @classmethod - async def deserialize( - cls: type[AgentThreadT], - serialized_thread_state: MutableMapping[str, Any], - *, - message_store: ChatMessageStoreProtocol | None = None, - **kwargs: Any, - ) -> AgentThreadT: - """Deserializes the state from a dictionary into a new AgentThread instance. - - Args: - serialized_thread_state: The serialized thread state as a dictionary. - - Keyword Args: - message_store: Optional ChatMessageStoreProtocol to use for managing messages. - If not provided, a new ChatMessageStore will be created if needed. - **kwargs: Additional arguments for deserialization. - - Returns: - A new AgentThread instance with properties set from the serialized state. - """ - state = AgentThreadState.from_dict(serialized_thread_state) - - if state.service_thread_id is not None: - return cls(service_thread_id=state.service_thread_id) - - # If we don't have any ChatMessageStoreProtocol state return here. 
- if state.chat_message_store_state is None: - return cls() - - if message_store is not None: - try: - await message_store.add_messages(state.chat_message_store_state.messages, **kwargs) - except Exception as ex: - raise AgentThreadException("Failed to deserialize the provided message store.") from ex - return cls(message_store=message_store) - try: - message_store = ChatMessageStore(messages=state.chat_message_store_state.messages, **kwargs) - except Exception as ex: - raise AgentThreadException("Failed to deserialize the message store.") from ex - return cls(message_store=message_store) - - async def update_from_thread_state( - self, - serialized_thread_state: MutableMapping[str, Any], - **kwargs: Any, - ) -> None: - """Deserializes the state from a dictionary into the thread properties. - - Args: - serialized_thread_state: The serialized thread state as a dictionary. - - Keyword Args: - **kwargs: Additional arguments for deserialization. - """ - state = AgentThreadState.from_dict(serialized_thread_state) - - if state.service_thread_id is not None: - self.service_thread_id = state.service_thread_id - # Since we have an ID, we should not have a chat message store and we can return here. - return - # If we don't have any ChatMessageStoreProtocol state return here. - if state.chat_message_store_state is None: - return - if self.message_store is not None: - await self.message_store.add_messages(state.chat_message_store_state.messages, **kwargs) - # If we don't have a chat message store yet, create an in-memory one. - return - # Create the message store from the default. - self.message_store = ChatMessageStore(messages=state.chat_message_store_state.messages, **kwargs) diff --git a/python/packages/core/tests/core/test_memory.py b/python/packages/core/tests/core/test_memory.py deleted file mode 100644 index bd83933e54..0000000000 --- a/python/packages/core/tests/core/test_memory.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. 
- -import sys -from collections.abc import MutableSequence -from typing import Any - -from agent_framework import Message -from agent_framework._memory import Context, ContextProvider - - -class MockContextProvider(ContextProvider): - """Mock ContextProvider for testing.""" - - def __init__(self, messages: list[Message] | None = None) -> None: - self.context_messages = messages - self.thread_created_called = False - self.invoked_called = False - self.invoking_called = False - self.thread_created_thread_id = None - self.new_messages = None - self.model_invoking_messages = None - - async def thread_created(self, thread_id: str | None) -> None: - """Track thread_created calls.""" - self.thread_created_called = True - self.thread_created_thread_id = thread_id - - async def invoked( - self, - request_messages: Any, - response_messages: Any | None = None, - invoke_exception: Exception | None = None, - **kwargs: Any, - ) -> None: - """Track invoked calls.""" - self.invoked_called = True - self.new_messages = request_messages - - async def invoking(self, messages: Message | MutableSequence[Message], **kwargs: Any) -> Context: - """Track invoking calls and return context.""" - self.invoking_called = True - self.model_invoking_messages = messages - context = Context() - context.messages = self.context_messages - return context - - -class MinimalContextProvider(ContextProvider): - """Minimal ContextProvider that only implements the required abstract method. - - Used to test the base class default implementations of thread_created, - invoked, __aenter__, and __aexit__. 
- """ - - async def invoking(self, messages: Message | MutableSequence[Message], **kwargs: Any) -> Context: - """Return empty context.""" - return Context() - - -class TestContext: - """Tests for Context class.""" - - def test_context_default_values(self) -> None: - """Test Context has correct default values.""" - context = Context() - assert context.instructions is None - assert context.messages == [] - assert context.tools == [] - - def test_context_with_values(self) -> None: - """Test Context can be initialized with values.""" - messages = [Message(role="user", text="Test message")] - context = Context(instructions="Test instructions", messages=messages) - assert context.instructions == "Test instructions" - assert len(context.messages) == 1 - assert context.messages[0].text == "Test message" - - -class TestContextProvider: - """Tests for ContextProvider class.""" - - async def test_thread_created(self) -> None: - """Test thread_created is called.""" - provider = MockContextProvider() - await provider.thread_created("test-thread-id") - assert provider.thread_created_called - assert provider.thread_created_thread_id == "test-thread-id" - - async def test_invoked(self) -> None: - """Test invoked is called.""" - provider = MockContextProvider() - message = Message(role="user", text="Test message") - await provider.invoked(message) - assert provider.invoked_called - assert provider.new_messages == message - - async def test_invoking(self) -> None: - """Test invoking is called and returns context.""" - provider = MockContextProvider(messages=[Message(role="user", text="Context message")]) - message = Message(role="user", text="Test message") - context = await provider.invoking(message) - assert provider.invoking_called - assert provider.model_invoking_messages == message - assert context.messages is not None - assert len(context.messages) == 1 - assert context.messages[0].text == "Context message" - - async def test_base_thread_created_does_nothing(self) -> None: - 
"""Test that base ContextProvider.thread_created does nothing by default.""" - provider = MinimalContextProvider() - await provider.thread_created("some-thread-id") - await provider.thread_created(None) - - async def test_base_invoked_does_nothing(self) -> None: - """Test that base ContextProvider.invoked does nothing by default.""" - provider = MinimalContextProvider() - message = Message(role="user", text="Test") - await provider.invoked(message) - await provider.invoked(message, response_messages=message) - await provider.invoked(message, invoke_exception=Exception("test")) - - async def test_base_aenter_returns_self(self) -> None: - """Test that base ContextProvider.__aenter__ returns self.""" - provider = MinimalContextProvider() - async with provider as p: - assert p is provider - - async def test_base_aexit_does_nothing(self) -> None: - """Test that base ContextProvider.__aexit__ handles exceptions gracefully.""" - provider = MinimalContextProvider() - await provider.__aexit__(None, None, None) - try: - raise ValueError("test error") - except ValueError: - exc_info = sys.exc_info() - await provider.__aexit__(exc_info[0], exc_info[1], exc_info[2]) diff --git a/python/packages/core/tests/core/test_threads.py b/python/packages/core/tests/core/test_threads.py deleted file mode 100644 index 5b3fc5ffd1..0000000000 --- a/python/packages/core/tests/core/test_threads.py +++ /dev/null @@ -1,600 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. 
- -from collections.abc import Sequence -from typing import Any - -import pytest - -from agent_framework import AgentThread, ChatMessageStore, Message -from agent_framework._threads import AgentThreadState, ChatMessageStoreState -from agent_framework.exceptions import AgentThreadException - - -class MockChatMessageStore: - """Mock implementation of ChatMessageStoreProtocol for testing.""" - - def __init__(self, messages: list[Message] | None = None) -> None: - self._messages = messages or [] - self._serialize_calls = 0 - self._deserialize_calls = 0 - - async def list_messages(self) -> list[Message]: - return self._messages - - async def add_messages(self, messages: Sequence[Message]) -> None: - self._messages.extend(messages) - - async def serialize(self, **kwargs: Any) -> Any: - self._serialize_calls += 1 - return {"messages": [msg.__dict__ for msg in self._messages], "kwargs": kwargs} - - async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) -> None: - self._deserialize_calls += 1 - if serialized_store_state and "messages" in serialized_store_state: - self._messages = serialized_store_state["messages"] - - @classmethod - async def deserialize(cls, serialized_store_state: Any, **kwargs: Any) -> "MockChatMessageStore": - instance = cls() - await instance.update_from_state(serialized_store_state, **kwargs) - return instance - - -@pytest.fixture -def sample_messages() -> list[Message]: - """Fixture providing sample chat messages for testing.""" - return [ - Message(role="user", text="Hello", message_id="msg1"), - Message(role="assistant", text="Hi there!", message_id="msg2"), - Message(role="user", text="How are you?", message_id="msg3"), - ] - - -@pytest.fixture -def sample_message() -> Message: - """Fixture providing a single sample chat message for testing.""" - return Message(role="user", text="Test message", message_id="test1") - - -class TestAgentThread: - """Test cases for AgentThread class.""" - - def test_init_with_no_parameters(self) 
-> None: - """Test AgentThread initialization with no parameters.""" - thread = AgentThread() - assert thread.service_thread_id is None - assert thread.message_store is None - - def test_init_with_service_thread_id(self) -> None: - """Test AgentThread initialization with service_thread_id.""" - service_thread_id = "test-conversation-123" - thread = AgentThread(service_thread_id=service_thread_id) - assert thread.service_thread_id == service_thread_id - assert thread.message_store is None - - def test_init_with_message_store(self) -> None: - """Test AgentThread initialization with message_store.""" - store = ChatMessageStore() - thread = AgentThread(message_store=store) - assert thread.service_thread_id is None - assert thread.message_store is store - - def test_service_thread_id_property_setter(self) -> None: - """Test service_thread_id property setter.""" - thread = AgentThread() - service_thread_id = "test-conversation-456" - - thread.service_thread_id = service_thread_id - assert thread.service_thread_id == service_thread_id - - def test_service_thread_id_setter_with_existing_message_store_raises_error(self) -> None: - """Test that setting service_thread_id when message_store exists raises AgentThreadException.""" - store = ChatMessageStore() - thread = AgentThread(message_store=store) - - with pytest.raises(AgentThreadException, match="Only the service_thread_id or message_store may be set"): - thread.service_thread_id = "test-conversation-789" - - def test_service_thread_id_setter_with_none_values(self) -> None: - """Test service_thread_id setter with None values does nothing.""" - thread = AgentThread() - thread.service_thread_id = None # Should not raise error - assert thread.service_thread_id is None - - def test_message_store_property_setter(self) -> None: - """Test message_store property setter.""" - thread = AgentThread() - store = ChatMessageStore() - - thread.message_store = store - assert thread.message_store is store - - def 
test_message_store_setter_with_existing_service_thread_id_raises_error(self) -> None: - """Test that setting message_store when service_thread_id exists raises AgentThreadException.""" - service_thread_id = "test-conversation-999" - thread = AgentThread(service_thread_id=service_thread_id) - store = ChatMessageStore() - - with pytest.raises(AgentThreadException, match="Only the service_thread_id or message_store may be set"): - thread.message_store = store - - def test_message_store_setter_with_none_values(self) -> None: - """Test message_store setter with None values does nothing.""" - thread = AgentThread() - thread.message_store = None # Should not raise error - assert thread.message_store is None - - async def test_get_messages_with_message_store(self, sample_messages: list[Message]) -> None: - """Test get_messages when message_store is set.""" - store = ChatMessageStore(sample_messages) - thread = AgentThread(message_store=store) - - assert thread.message_store is not None - - messages: list[Message] = await thread.message_store.list_messages() - - assert messages is not None - assert len(messages) == 3 - assert messages[0].text == "Hello" - assert messages[1].text == "Hi there!" - assert messages[2].text == "How are you?" 
- - async def test_get_messages_with_no_message_store(self) -> None: - """Test get_messages when no message_store is set.""" - thread = AgentThread() - - assert thread.message_store is None - - async def test_on_new_messages_with_service_thread_id(self, sample_message: Message) -> None: - """Test _on_new_messages when service_thread_id is set (should do nothing).""" - thread = AgentThread(service_thread_id="test-conv") - - await thread.on_new_messages(sample_message) - - # Should not create a message store - assert thread.message_store is None - - async def test_on_new_messages_single_message_creates_store(self, sample_message: Message) -> None: - """Test _on_new_messages with single message creates ChatMessageStore.""" - thread = AgentThread() - - await thread.on_new_messages(sample_message) - - assert thread.message_store is not None - assert isinstance(thread.message_store, ChatMessageStore) - messages = await thread.message_store.list_messages() - assert len(messages) == 1 - assert messages[0].text == "Test message" - - async def test_on_new_messages_multiple_messages(self, sample_messages: list[Message]) -> None: - """Test _on_new_messages with multiple messages.""" - thread = AgentThread() - - await thread.on_new_messages(sample_messages) - - assert thread.message_store is not None - messages = await thread.message_store.list_messages() - assert len(messages) == 3 - - async def test_on_new_messages_with_existing_store(self, sample_message: Message) -> None: - """Test _on_new_messages adds to existing message store.""" - initial_messages = [Message(role="user", text="Initial", message_id="init1")] - store = ChatMessageStore(initial_messages) - thread = AgentThread(message_store=store) - - await thread.on_new_messages(sample_message) - - assert thread.message_store is not None - messages = await thread.message_store.list_messages() - assert len(messages) == 2 - assert messages[0].text == "Initial" - assert messages[1].text == "Test message" - - async def 
test_deserialize_with_service_thread_id(self) -> None: - """Test _deserialize with service_thread_id.""" - serialized_data = {"service_thread_id": "test-conv-123", "chat_message_store_state": None} - - thread = await AgentThread.deserialize(serialized_data) - - assert thread.service_thread_id == "test-conv-123" - assert thread.message_store is None - - async def test_deserialize_with_store_state(self, sample_messages: list[Message]) -> None: - """Test _deserialize with chat_message_store_state.""" - store_state = {"messages": sample_messages} - serialized_data = {"service_thread_id": None, "chat_message_store_state": store_state} - - thread = await AgentThread.deserialize(serialized_data) - - assert thread.service_thread_id is None - assert thread.message_store is not None - assert isinstance(thread.message_store, ChatMessageStore) - - async def test_deserialize_with_no_state(self) -> None: - """Test _deserialize with no state.""" - thread = AgentThread() - serialized_data = {"service_thread_id": None, "chat_message_store_state": None} - - await thread.deserialize(serialized_data) - - assert thread.service_thread_id is None - assert thread.message_store is None - - async def test_deserialize_with_existing_store(self) -> None: - """Test _deserialize with existing message store.""" - store = MockChatMessageStore() - thread = AgentThread(message_store=store) - serialized_data: dict[str, Any] = { - "service_thread_id": None, - "chat_message_store_state": {"messages": [Message(role="user", text="test")]}, - } - - await thread.update_from_thread_state(serialized_data) - - assert store._messages - assert store._messages[0].text == "test" - - async def test_serialize_with_service_thread_id(self) -> None: - """Test serialize with service_thread_id.""" - thread = AgentThread(service_thread_id="test-conv-456") - - result = await thread.serialize() - - assert result["service_thread_id"] == "test-conv-456" - assert result["chat_message_store_state"] is None - - async def 
test_serialize_with_message_store(self) -> None: - """Test serialize with message_store.""" - store = MockChatMessageStore() - thread = AgentThread(message_store=store) - - result = await thread.serialize() - - assert result["service_thread_id"] is None - assert result["chat_message_store_state"] is not None - assert store._serialize_calls == 1 # pyright: ignore[reportPrivateUsage] - - async def test_serialize_with_no_state(self) -> None: - """Test serialize with no state.""" - thread = AgentThread() - - result = await thread.serialize() - - assert result["service_thread_id"] is None - assert result["chat_message_store_state"] is None - - async def test_serialize_with_kwargs(self) -> None: - """Test serialize passes kwargs to message store.""" - store = MockChatMessageStore() - thread = AgentThread(message_store=store) - - await thread.serialize(custom_param="test_value") - - assert store._serialize_calls == 1 # pyright: ignore[reportPrivateUsage] - - async def test_serialize_round_trip_messages(self, sample_messages: list[Message]) -> None: - """Test a roundtrip of the serialization.""" - store = ChatMessageStore(sample_messages) - thread = AgentThread(message_store=store) - new_thread = await AgentThread.deserialize(await thread.serialize()) - assert new_thread.message_store is not None - new_messages = await new_thread.message_store.list_messages() - assert len(new_messages) == len(sample_messages) - assert {new.text for new in new_messages} == {orig.text for orig in sample_messages} - - async def test_serialize_round_trip_thread_id(self) -> None: - """Test a roundtrip of the serialization.""" - thread = AgentThread(service_thread_id="test-1234") - new_thread = await AgentThread.deserialize(await thread.serialize()) - assert new_thread.message_store is None - assert new_thread.service_thread_id == "test-1234" - - -class TestChatMessageList: - """Test cases for ChatMessageStore class.""" - - def test_init_empty(self) -> None: - """Test ChatMessageStore 
initialization with no messages.""" - store = ChatMessageStore() - assert len(store.messages) == 0 - - def test_init_with_messages(self, sample_messages: list[Message]) -> None: - """Test ChatMessageStore initialization with messages.""" - store = ChatMessageStore(sample_messages) - assert len(store.messages) == 3 - - async def test_add_messages(self, sample_messages: list[Message]) -> None: - """Test adding messages to the store.""" - store = ChatMessageStore() - - await store.add_messages(sample_messages) - - assert len(store.messages) == 3 - messages = await store.list_messages() - assert messages[0].text == "Hello" - - async def test_get_messages(self, sample_messages: list[Message]) -> None: - """Test getting messages from the store.""" - store = ChatMessageStore(sample_messages) - - messages = await store.list_messages() - - assert len(messages) == 3 - assert messages[0].message_id == "msg1" - - async def test_serialize_state(self, sample_messages: list[Message]) -> None: - """Test serializing store state.""" - store = ChatMessageStore(sample_messages) - - result = await store.serialize() - - assert "messages" in result - assert len(result["messages"]) == 3 - - async def test_serialize_state_empty(self) -> None: - """Test serializing empty store state.""" - store = ChatMessageStore() - - result = await store.serialize() - - assert "messages" in result - assert len(result["messages"]) == 0 - - async def test_deserialize_state(self, sample_messages: list[Message]) -> None: - """Test deserializing store state.""" - store = ChatMessageStore() - state_data = {"messages": sample_messages} - - await store.update_from_state(state_data) - - messages = await store.list_messages() - assert len(messages) == 3 - assert messages[0].text == "Hello" - - async def test_deserialize_state_none(self) -> None: - """Test deserializing None state.""" - store = ChatMessageStore() - - await store.update_from_state(None) - - assert len(store.messages) == 0 - - async def 
test_deserialize_state_empty(self) -> None: - """Test deserializing empty state.""" - store = ChatMessageStore() - - await store.update_from_state({}) - - assert len(store.messages) == 0 - - -class TestStoreState: - """Test cases for ChatMessageStoreState class.""" - - def test_init(self, sample_messages: list[Message]) -> None: - """Test ChatMessageStoreState initialization.""" - state = ChatMessageStoreState(messages=sample_messages) - - assert len(state.messages) == 3 - assert state.messages[0].text == "Hello" - - def test_init_empty(self) -> None: - """Test ChatMessageStoreState initialization with empty messages.""" - state = ChatMessageStoreState(messages=[]) - - assert len(state.messages) == 0 - - def test_init_none(self) -> None: - """Test ChatMessageStoreState initialization with None messages.""" - state = ChatMessageStoreState(messages=None) - - assert len(state.messages) == 0 - - def test_init_no_messages_arg(self) -> None: - """Test ChatMessageStoreState initialization without messages argument.""" - state = ChatMessageStoreState() - - assert len(state.messages) == 0 - - -class TestThreadState: - """Test cases for AgentThreadState class.""" - - def test_init_with_service_thread_id(self) -> None: - """Test AgentThreadState initialization with service_thread_id.""" - state = AgentThreadState(service_thread_id="test-conv-123") - - assert state.service_thread_id == "test-conv-123" - assert state.chat_message_store_state is None - - def test_init_with_chat_message_store_state(self) -> None: - """Test AgentThreadState initialization with chat_message_store_state.""" - store_data: dict[str, Any] = {"messages": []} - state = AgentThreadState.from_dict({"chat_message_store_state": store_data}) - - assert state.service_thread_id is None - assert state.chat_message_store_state.messages == [] - - def test_init_with_both(self) -> None: - """Test AgentThreadState initialization with both parameters.""" - store_data: dict[str, Any] = {"messages": []} - with 
pytest.raises(AgentThreadException): - AgentThreadState(service_thread_id="test-conv-123", chat_message_store_state=store_data) - - def test_init_defaults(self) -> None: - """Test AgentThreadState initialization with defaults.""" - state = AgentThreadState() - - assert state.service_thread_id is None - assert state.chat_message_store_state is None - - def test_init_with_chat_message_store_state_no_messages(self) -> None: - """Test AgentThreadState initialization with chat_message_store_state without messages field. - - This tests the scenario where a custom ChatMessageStore (like RedisChatMessageStore) - serializes its state without a 'messages' field, containing only configuration data - like thread_id, redis_url, etc. - """ - store_data: dict[str, Any] = { - "type": "redis_store_state", - "thread_id": "test_thread_123", - "redis_url": "redis://localhost:6379", - "key_prefix": "chat_messages", - } - state = AgentThreadState.from_dict({"chat_message_store_state": store_data}) - - assert state.service_thread_id is None - assert state.chat_message_store_state is not None - assert state.chat_message_store_state.messages == [] - - def test_init_with_chat_message_store_state_object(self) -> None: - """Test AgentThreadState initialization with ChatMessageStoreState object.""" - store_state = ChatMessageStoreState(messages=[Message(role="user", text="test")]) - state = AgentThreadState(chat_message_store_state=store_state) - - assert state.service_thread_id is None - assert state.chat_message_store_state is store_state - assert len(state.chat_message_store_state.messages) == 1 - - def test_init_with_invalid_chat_message_store_state_type(self) -> None: - """Test AgentThreadState initialization with invalid chat_message_store_state type.""" - with pytest.raises(TypeError, match="Could not parse ChatMessageStoreState"): - AgentThreadState(chat_message_store_state="invalid_type") # type: ignore[arg-type] - - -class TestChatMessageStoreStateEdgeCases: - """Additional edge case 
tests for ChatMessageStoreState.""" - - def test_init_with_invalid_messages_type(self) -> None: - """Test ChatMessageStoreState initialization with invalid messages type.""" - with pytest.raises(TypeError, match="Messages should be a list"): - ChatMessageStoreState(messages="invalid") # type: ignore[arg-type] - - def test_init_with_dict_messages(self) -> None: - """Test ChatMessageStoreState initialization with dict messages.""" - messages = [ - {"role": "user", "text": "Hello"}, - {"role": "assistant", "text": "Hi there!"}, - ] - state = ChatMessageStoreState(messages=messages) - - assert len(state.messages) == 2 - assert isinstance(state.messages[0], Message) - assert state.messages[0].text == "Hello" - - -class TestChatMessageStoreEdgeCases: - """Additional edge case tests for ChatMessageStore.""" - - async def test_deserialize_class_method(self) -> None: - """Test ChatMessageStore.deserialize class method.""" - serialized_data = { - "messages": [ - {"role": "user", "text": "Hello", "message_id": "msg1"}, - ] - } - - store = await ChatMessageStore.deserialize(serialized_data) - - assert isinstance(store, ChatMessageStore) - messages = await store.list_messages() - assert len(messages) == 1 - assert messages[0].text == "Hello" - - async def test_deserialize_empty_state(self) -> None: - """Test ChatMessageStore.deserialize with empty state.""" - serialized_data: dict[str, Any] = {"messages": []} - - store = await ChatMessageStore.deserialize(serialized_data) - - assert isinstance(store, ChatMessageStore) - messages = await store.list_messages() - assert len(messages) == 0 - - -class TestAgentThreadEdgeCases: - """Additional edge case tests for AgentThread.""" - - def test_is_initialized_with_service_thread_id(self) -> None: - """Test is_initialized property when service_thread_id is set.""" - thread = AgentThread(service_thread_id="test-123") - assert thread.is_initialized is True - - def test_is_initialized_with_message_store(self) -> None: - """Test 
is_initialized property when message_store is set.""" - store = ChatMessageStore() - thread = AgentThread(message_store=store) - assert thread.is_initialized is True - - def test_is_initialized_with_nothing(self) -> None: - """Test is_initialized property when nothing is set.""" - thread = AgentThread() - assert thread.is_initialized is False - - async def test_deserialize_with_custom_message_store(self) -> None: - """Test deserialize using a custom message store.""" - serialized_data = { - "service_thread_id": None, - "chat_message_store_state": { - "messages": [{"role": "user", "text": "Hello"}], - }, - } - custom_store = MockChatMessageStore() - - thread = await AgentThread.deserialize(serialized_data, message_store=custom_store) - - assert thread.message_store is custom_store - messages = await custom_store.list_messages() - assert len(messages) == 1 - - async def test_deserialize_with_failing_message_store_raises(self) -> None: - """Test deserialize raises AgentThreadException when message store fails.""" - - class FailingStore: - async def add_messages(self, messages: Sequence[Message], **kwargs: Any) -> None: - raise RuntimeError("Store failed") - - serialized_data = { - "service_thread_id": None, - "chat_message_store_state": { - "messages": [{"role": "user", "text": "Hello"}], - }, - } - failing_store = FailingStore() - - with pytest.raises(AgentThreadException, match="Failed to deserialize"): - await AgentThread.deserialize(serialized_data, message_store=failing_store) - - async def test_update_from_thread_state_with_service_thread_id(self) -> None: - """Test update_from_thread_state sets service_thread_id.""" - thread = AgentThread() - serialized_data = {"service_thread_id": "new-thread-id"} - - await thread.update_from_thread_state(serialized_data) - - assert thread.service_thread_id == "new-thread-id" - - async def test_update_from_thread_state_with_empty_chat_state(self) -> None: - """Test update_from_thread_state with empty 
chat_message_store_state.""" - thread = AgentThread() - serialized_data = {"service_thread_id": None, "chat_message_store_state": None} - - await thread.update_from_thread_state(serialized_data) - - assert thread.message_store is None - - async def test_update_from_thread_state_creates_message_store(self) -> None: - """Test update_from_thread_state creates message store if not existing.""" - thread = AgentThread() - serialized_data = { - "service_thread_id": None, - "chat_message_store_state": { - "messages": [{"role": "user", "text": "Hello"}], - }, - } - - await thread.update_from_thread_state(serialized_data) - - assert thread.message_store is not None - messages = await thread.message_store.list_messages() - assert len(messages) == 1 diff --git a/python/packages/devui/agent_framework_devui/_conversations.py b/python/packages/devui/agent_framework_devui/_conversations.py index 51f42a2389..79ec7c7034 100644 --- a/python/packages/devui/agent_framework_devui/_conversations.py +++ b/python/packages/devui/agent_framework_devui/_conversations.py @@ -3,7 +3,7 @@ """Conversation storage abstraction for OpenAI Conversations API. This module provides a clean abstraction layer for managing conversations -while wrapping AgentFramework's AgentThread underneath. +with in-memory message storage. """ from __future__ import annotations @@ -13,7 +13,7 @@ from abc import ABC, abstractmethod from typing import Any, Literal, cast -from agent_framework import AgentSession, AgentThread, Message +from agent_framework import AgentSession, Message from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage from openai.types.conversations import Conversation, ConversationDeletedResource from openai.types.conversations.conversation_item import ConversationItem @@ -38,14 +38,14 @@ class ConversationStore(ABC): """Abstract base class for conversation storage. Provides OpenAI Conversations API interface while managing - AgentThread instances underneath. 
+ message storage internally. """ @abstractmethod def create_conversation( self, metadata: dict[str, str] | None = None, conversation_id: str | None = None ) -> Conversation: - """Create a new conversation (wraps AgentThread creation). + """Create a new conversation. Args: metadata: Optional metadata dict (e.g., {"agent_id": "weather_agent"}) @@ -86,7 +86,7 @@ def update_conversation(self, conversation_id: str, metadata: dict[str, str]) -> @abstractmethod def delete_conversation(self, conversation_id: str) -> ConversationDeletedResource: - """Delete conversation (including AgentThread). + """Delete conversation. Args: conversation_id: Conversation ID @@ -101,7 +101,7 @@ def delete_conversation(self, conversation_id: str) -> ConversationDeletedResour @abstractmethod async def add_items(self, conversation_id: str, items: list[dict[str, Any]]) -> list[ConversationItem]: - """Add items to conversation (syncs to AgentThread.message_store). + """Add items to conversation. Args: conversation_id: Conversation ID @@ -119,7 +119,7 @@ async def add_items(self, conversation_id: str, items: list[dict[str, Any]]) -> async def list_items( self, conversation_id: str, limit: int = 100, after: str | None = None, order: str = "asc" ) -> tuple[list[ConversationItem], bool]: - """List conversation items from AgentThread.message_store. + """List conversation items. Args: conversation_id: Conversation ID @@ -183,7 +183,7 @@ def add_trace(self, conversation_id: str, trace_event: dict[str, Any]) -> None: """Add a trace event to the conversation for context inspection. Traces capture execution metadata like token usage, timing, and LLM context - that isn't stored in the AgentThread but is useful for debugging. + that is useful for debugging. Args: conversation_id: Conversation ID @@ -205,17 +205,17 @@ def get_traces(self, conversation_id: str) -> list[dict[str, Any]]: class InMemoryConversationStore(ConversationStore): - """In-memory conversation storage wrapping AgentThread. 
+ """In-memory conversation storage. This implementation stores conversations in memory with their - underlying AgentThread instances for execution. + underlying message lists and AgentSession instances for execution. """ def __init__(self) -> None: """Initialize in-memory conversation storage. Storage structure maps conversation IDs to conversation data including - the underlying AgentThread, metadata, and cached ConversationItems. + messages, metadata, and cached ConversationItems. """ self._conversations: dict[str, dict[str, Any]] = {} @@ -225,12 +225,12 @@ def __init__(self) -> None: def create_conversation( self, metadata: dict[str, str] | None = None, conversation_id: str | None = None ) -> Conversation: - """Create a new conversation with underlying AgentThread and checkpoint storage.""" + """Create a new conversation with message storage and checkpoint storage.""" conv_id = conversation_id or f"conv_{uuid.uuid4().hex}" created_at = int(time.time()) - # Create AgentThread for internal message storage and AgentSession for execution - thread = AgentThread() + # Create message list for internal storage and AgentSession for execution + messages: list[Message] = [] session = AgentSession(session_id=conv_id) # Create session-scoped checkpoint storage (one per conversation) @@ -238,9 +238,9 @@ def create_conversation( self._conversations[conv_id] = { "id": conv_id, - "thread": thread, + "messages": messages, "session": session, - "checkpoint_storage": checkpoint_storage, # Stored alongside thread + "checkpoint_storage": checkpoint_storage, "metadata": metadata or {}, "created_at": created_at, "items": [], @@ -281,7 +281,7 @@ def update_conversation(self, conversation_id: str, metadata: dict[str, str]) -> ) def delete_conversation(self, conversation_id: str) -> ConversationDeletedResource: - """Delete conversation and its AgentThread.""" + """Delete conversation.""" if conversation_id not in self._conversations: raise ValueError(f"Conversation {conversation_id} not 
found") @@ -292,14 +292,14 @@ def delete_conversation(self, conversation_id: str) -> ConversationDeletedResour return ConversationDeletedResource(id=conversation_id, object="conversation.deleted", deleted=True) async def add_items(self, conversation_id: str, items: list[dict[str, Any]]) -> list[ConversationItem]: - """Add items to conversation and sync to AgentThread.""" + """Add items to conversation.""" conv_data = self._conversations.get(conversation_id) if not conv_data: raise ValueError(f"Conversation {conversation_id} not found") - thread: AgentThread = conv_data["thread"] + stored_messages: list[Message] = conv_data["messages"] - # Convert items to ChatMessages and add to thread + # Convert items to Messages and add to storage chat_messages = [] for item in items: # Simple conversion - assume text content for now @@ -310,8 +310,8 @@ async def add_items(self, conversation_id: str, items: list[dict[str, Any]]) -> chat_msg = Message(role=role, text=text) # type: ignore[arg-type] chat_messages.append(chat_msg) - # Add messages to AgentThread - await thread.on_new_messages(chat_messages) + # Add messages to internal storage + stored_messages.extend(chat_messages) # Create Message objects (ConversationItem is a Union - use concrete Message type) conv_items: list[ConversationItem] = [] @@ -356,9 +356,9 @@ async def add_items(self, conversation_id: str, items: list[dict[str, Any]]) -> async def list_items( self, conversation_id: str, limit: int = 100, after: str | None = None, order: str = "asc" ) -> tuple[list[ConversationItem], bool]: - """List conversation items from AgentThread message store. + """List conversation items. 
- Converts AgentFramework ChatMessages to proper OpenAI ConversationItem types: + Converts stored Messages to proper OpenAI ConversationItem types: - Messages with text/images/files → Message - Function calls → ResponseFunctionToolCallItem - Function results → ResponseFunctionToolCallOutputItem @@ -367,119 +367,118 @@ async def list_items( if not conv_data: raise ValueError(f"Conversation {conversation_id} not found") - thread: AgentThread = conv_data["thread"] + stored_messages: list[Message] = conv_data["messages"] - # Get messages from thread's message store + # Convert stored messages to ConversationItem types items: list[ConversationItem] = [] - if thread.message_store: - af_messages = await thread.message_store.list_messages() - - # Convert each AgentFramework Message to appropriate ConversationItem type(s) - for i, msg in enumerate(af_messages): - item_id = f"item_{i}" - role_str = msg.role if hasattr(msg.role, "value") else str(msg.role) - role = cast(MessageRole, role_str) # Safe: Agent Framework roles match OpenAI roles - - # Process each content item in the message - # A single Message may produce multiple ConversationItems - # (e.g., a message with both text and a function call) - message_contents: list[TextContent | ResponseInputImage | ResponseInputFile] = [] - function_calls = [] - function_results = [] - - for content in msg.contents: - content_type = getattr(content, "type", None) - - if content_type == "text": - # Text content for Message - text_value = getattr(content, "text", "") - message_contents.append(TextContent(type="text", text=text_value)) - - elif content_type == "data": - # Data content (images, files, PDFs) - uri = getattr(content, "uri", "") - media_type = getattr(content, "media_type", None) - - if media_type and media_type.startswith("image/"): - # Convert to ResponseInputImage - message_contents.append( - ResponseInputImage(type="input_image", image_url=uri, detail="auto") - ) - else: - # Convert to ResponseInputFile - # Extract 
filename from URI if possible - filename = None - if media_type == "application/pdf": - filename = "document.pdf" - - message_contents.append( - ResponseInputFile(type="input_file", file_url=uri, filename=filename) - ) + af_messages = stored_messages - elif content_type == "function_call": - # Function call - create separate ConversationItem - call_id = getattr(content, "call_id", None) - name = getattr(content, "name", "") - arguments = getattr(content, "arguments", "") - - if call_id and name: - function_calls.append( - ResponseFunctionToolCallItem( - id=f"{item_id}_call_{call_id}", - call_id=call_id, - name=name, - arguments=arguments, - type="function_call", - status="completed", - ) - ) + # Convert each AgentFramework Message to appropriate ConversationItem type(s) + for i, msg in enumerate(af_messages): + item_id = f"item_{i}" + role_str = msg.role if hasattr(msg.role, "value") else str(msg.role) + role = cast(MessageRole, role_str) # Safe: Agent Framework roles match OpenAI roles - elif content_type == "function_result": - # Function result - create separate ConversationItem - call_id = getattr(content, "call_id", None) - # Output is stored in the 'result' field of FunctionResultContent - result_value = getattr(content, "result", None) - # Convert result to string (it could be dict, list, or other types) - if result_value is None: - output = "" - elif isinstance(result_value, str): - output = result_value - else: - import json - - try: - output = json.dumps(result_value) - except (TypeError, ValueError): - output = str(result_value) - - if call_id: - function_results.append( - ResponseFunctionToolCallOutputItem( - id=f"{item_id}_result_{call_id}", - call_id=call_id, - output=output, - type="function_call_output", - status="completed", - ) + # Process each content item in the message + # A single Message may produce multiple ConversationItems + # (e.g., a message with both text and a function call) + message_contents: list[TextContent | ResponseInputImage | 
ResponseInputFile] = [] + function_calls = [] + function_results = [] + + for content in msg.contents: + content_type = getattr(content, "type", None) + + if content_type == "text": + # Text content for Message + text_value = getattr(content, "text", "") + message_contents.append(TextContent(type="text", text=text_value)) + + elif content_type == "data": + # Data content (images, files, PDFs) + uri = getattr(content, "uri", "") + media_type = getattr(content, "media_type", None) + + if media_type and media_type.startswith("image/"): + # Convert to ResponseInputImage + message_contents.append( + ResponseInputImage(type="input_image", image_url=uri, detail="auto") + ) + else: + # Convert to ResponseInputFile + # Extract filename from URI if possible + filename = None + if media_type == "application/pdf": + filename = "document.pdf" + + message_contents.append( + ResponseInputFile(type="input_file", file_url=uri, filename=filename) + ) + + elif content_type == "function_call": + # Function call - create separate ConversationItem + call_id = getattr(content, "call_id", None) + name = getattr(content, "name", "") + arguments = getattr(content, "arguments", "") + + if call_id and name: + function_calls.append( + ResponseFunctionToolCallItem( + id=f"{item_id}_call_{call_id}", + call_id=call_id, + name=name, + arguments=arguments, + type="function_call", + status="completed", ) + ) + + elif content_type == "function_result": + # Function result - create separate ConversationItem + call_id = getattr(content, "call_id", None) + # Output is stored in the 'result' field of FunctionResultContent + result_value = getattr(content, "result", None) + # Convert result to string (it could be dict, list, or other types) + if result_value is None: + output = "" + elif isinstance(result_value, str): + output = result_value + else: + import json + + try: + output = json.dumps(result_value) + except (TypeError, ValueError): + output = str(result_value) + + if call_id: + 
function_results.append( + ResponseFunctionToolCallOutputItem( + id=f"{item_id}_result_{call_id}", + call_id=call_id, + output=output, + type="function_call_output", + status="completed", + ) + ) + + # Create ConversationItems based on what we found + # If message has text/images/files, create a Message item + if message_contents: + message = OpenAIMessage( + id=item_id, + type="message", + role=role, # type: ignore + content=message_contents, # type: ignore + status="completed", + ) + items.append(message) - # Create ConversationItems based on what we found - # If message has text/images/files, create a Message item - if message_contents: - message = OpenAIMessage( - id=item_id, - type="message", - role=role, # type: ignore - content=message_contents, # type: ignore - status="completed", - ) - items.append(message) - - # Add function call items - items.extend(function_calls) + # Add function call items + items.extend(function_calls) - # Add function result items - items.extend(function_results) + # Add function result items + items.extend(function_results) # Include checkpoints from checkpoint storage as conversation items checkpoint_storage = conv_data.get("checkpoint_storage") @@ -600,7 +599,7 @@ def add_trace(self, conversation_id: str, trace_event: dict[str, Any]) -> None: """Add a trace event to the conversation for context inspection. Traces capture execution metadata like token usage, timing, and LLM context - that isn't stored in the AgentThread but is useful for debugging. + that is useful for debugging. 
Args: conversation_id: Conversation ID diff --git a/python/packages/devui/tests/devui/test_conversations.py b/python/packages/devui/tests/devui/test_conversations.py index 812e0e718f..a9e7ac6441 100644 --- a/python/packages/devui/tests/devui/test_conversations.py +++ b/python/packages/devui/tests/devui/test_conversations.py @@ -199,21 +199,13 @@ async def test_list_items_pagination(): @pytest.mark.asyncio async def test_list_items_converts_function_calls(): """Test that list_items properly converts function calls to ResponseFunctionToolCallItem.""" - from agent_framework import ChatMessageStore, Message + from agent_framework import Message store = InMemoryConversationStore() # Create conversation conversation = store.create_conversation(metadata={"agent_id": "test_agent"}) - # Get the underlying thread for internal message store setup - thread = store._conversations[conversation.id]["thread"] - assert thread is not None - - # Initialize message store if not present - if thread.message_store is None: - thread.message_store = ChatMessageStore() - # Simulate messages from agent execution with function calls messages = [ Message(role="user", contents=[{"type": "text", "text": "What's the weather in SF?"}]), @@ -241,8 +233,8 @@ async def test_list_items_converts_function_calls(): Message(role="assistant", contents=[{"type": "text", "text": "The weather is sunny, 65°F"}]), ] - # Add messages to thread - await thread.on_new_messages(messages) + # Add messages to internal storage + store._conversations[conversation.id]["messages"].extend(messages) # List conversation items items, has_more = await store.list_items(conversation.id) @@ -284,20 +276,13 @@ async def test_list_items_converts_function_calls(): @pytest.mark.asyncio async def test_list_items_handles_images_and_files(): """Test that list_items properly converts data content (images/files) to OpenAI types.""" - from agent_framework import ChatMessageStore, Message + from agent_framework import Message store = 
InMemoryConversationStore() # Create conversation conversation = store.create_conversation(metadata={"agent_id": "test_agent"}) - # Get the underlying thread for internal message store setup - thread = store._conversations[conversation.id]["thread"] - assert thread is not None - - if thread.message_store is None: - thread.message_store = ChatMessageStore() - # Simulate message with image and file messages = [ Message( @@ -310,7 +295,8 @@ async def test_list_items_handles_images_and_files(): ), ] - await thread.on_new_messages(messages) + # Add messages to internal storage + store._conversations[conversation.id]["messages"].extend(messages) # List items items, has_more = await store.list_items(conversation.id) diff --git a/python/packages/mem0/agent_framework_mem0/_provider.py b/python/packages/mem0/agent_framework_mem0/_provider.py deleted file mode 100644 index d2ba0e7832..0000000000 --- a/python/packages/mem0/agent_framework_mem0/_provider.py +++ /dev/null @@ -1,239 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. 
- -from __future__ import annotations - -import sys -from collections.abc import MutableSequence, Sequence -from contextlib import AbstractAsyncContextManager -from typing import Any - -from agent_framework import Context, ContextProvider, Message -from agent_framework.exceptions import ServiceInitializationError -from mem0 import AsyncMemory, AsyncMemoryClient - -if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover -else: - from typing_extensions import override # type: ignore[import] # pragma: no cover - -if sys.version_info >= (3, 11): - from typing import NotRequired, Self, TypedDict # pragma: no cover -else: - from typing_extensions import NotRequired, Self, TypedDict # pragma: no cover - - -# Type aliases for Mem0 search response formats (v1.1 and v2; v1 is deprecated, but matches the type definition for v2) -class MemorySearchResponse_v1_1(TypedDict): - results: list[dict[str, Any]] - relations: NotRequired[list[dict[str, Any]]] - - -MemorySearchResponse_v2 = list[dict[str, Any]] - - -class Mem0Provider(ContextProvider): - """Mem0 Context Provider. - - Note: - Mem0's telemetry is disabled by default when using this package. - To enable telemetry, set the environment variable ``MEM0_TELEMETRY=true`` before - importing this package. - """ - - def __init__( - self, - mem0_client: AsyncMemory | AsyncMemoryClient | None = None, - api_key: str | None = None, - application_id: str | None = None, - agent_id: str | None = None, - thread_id: str | None = None, - user_id: str | None = None, - scope_to_per_operation_thread_id: bool = False, - context_prompt: str = ContextProvider.DEFAULT_CONTEXT_PROMPT, - ) -> None: - """Initializes a new instance of the Mem0Provider class. - - Args: - mem0_client: A pre-created Mem0 MemoryClient or None to create a default client. - api_key: The API key for authenticating with the Mem0 API. If not - provided, it will attempt to use the MEM0_API_KEY environment variable. 
- application_id: The application ID for scoping memories or None. - agent_id: The agent ID for scoping memories or None. - thread_id: The thread ID for scoping memories or None. - user_id: The user ID for scoping memories or None. - scope_to_per_operation_thread_id: Whether to scope memories to per-operation thread ID. - context_prompt: The prompt to prepend to retrieved memories. - """ - should_close_client = False - if mem0_client is None: - mem0_client = AsyncMemoryClient(api_key=api_key) - should_close_client = True - - self.api_key = api_key - self.application_id = application_id - self.agent_id = agent_id - self.thread_id = thread_id - self.user_id = user_id - self.scope_to_per_operation_thread_id = scope_to_per_operation_thread_id - self.context_prompt = context_prompt - self.mem0_client = mem0_client - self._per_operation_thread_id: str | None = None - self._should_close_client = should_close_client - - async def __aenter__(self) -> Self: - """Async context manager entry.""" - if self.mem0_client and isinstance(self.mem0_client, AbstractAsyncContextManager): - await self.mem0_client.__aenter__() - return self - - async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: - """Async context manager exit.""" - if self._should_close_client and self.mem0_client and isinstance(self.mem0_client, AbstractAsyncContextManager): - await self.mem0_client.__aexit__(exc_type, exc_val, exc_tb) - - async def thread_created(self, thread_id: str | None = None) -> None: - """Called when a new thread is created. - - Args: - thread_id: The ID of the thread or None. 
- """ - self._validate_per_operation_thread_id(thread_id) - self._per_operation_thread_id = self._per_operation_thread_id or thread_id - - @override - async def invoked( - self, - request_messages: Message | Sequence[Message], - response_messages: Message | Sequence[Message] | None = None, - invoke_exception: Exception | None = None, - **kwargs: Any, - ) -> None: - self._validate_filters() - - request_messages_list = [request_messages] if isinstance(request_messages, Message) else list(request_messages) - response_messages_list = ( - [response_messages] - if isinstance(response_messages, Message) - else list(response_messages) - if response_messages - else [] - ) - messages_list = [*request_messages_list, *response_messages_list] - - # Extract role value - it may be a Role enum or a string - def get_role_value(role: Any) -> str: - return role.value if hasattr(role, "value") else str(role) - - messages: list[dict[str, str]] = [ - {"role": get_role_value(message.role), "content": message.text} - for message in messages_list - if get_role_value(message.role) in {"user", "assistant", "system"} and message.text and message.text.strip() - ] - - if messages: - await self.mem0_client.add( # type: ignore[misc] - messages=messages, - user_id=self.user_id, - agent_id=self.agent_id, - run_id=self._per_operation_thread_id if self.scope_to_per_operation_thread_id else self.thread_id, - metadata={"application_id": self.application_id}, - ) - - @override - async def invoking(self, messages: Message | MutableSequence[Message], **kwargs: Any) -> Context: - """Called before invoking the AI model to provide context. - - Args: - messages: List of new messages in the thread. - - Keyword Args: - **kwargs: not used at present. - - Returns: - Context: Context object containing instructions with memories. 
- """ - self._validate_filters() - messages_list = [messages] if isinstance(messages, Message) else list(messages) - input_text = "\n".join(msg.text for msg in messages_list if msg and msg.text and msg.text.strip()) - - # Validate input text is not empty before searching (possible for function approval responses) - if not input_text.strip(): - return Context(messages=None) - - # Build filters from init parameters - filters = self._build_filters() - - search_response: MemorySearchResponse_v1_1 | MemorySearchResponse_v2 = await self.mem0_client.search( # type: ignore[misc] - query=input_text, - filters=filters, - ) - - # Depending on the API version, the response schema varies slightly - if isinstance(search_response, list): - memories = search_response - elif isinstance(search_response, dict) and "results" in search_response: - memories = search_response["results"] - else: - # Fallback for unexpected schema - return response as text as-is - memories = [search_response] - - line_separated_memories = "\n".join(memory.get("memory", "") for memory in memories) - - return Context( - messages=[Message(role="user", text=f"{self.context_prompt}\n{line_separated_memories}")] - if line_separated_memories - else None - ) - - def _validate_filters(self) -> None: - """Validates that at least one filter is provided. - - Raises: - ServiceInitializationError: If no filters are provided. - """ - if not self.agent_id and not self.user_id and not self.application_id and not self.thread_id: - raise ServiceInitializationError( - "At least one of the filters: agent_id, user_id, application_id, or thread_id is required." - ) - - def _build_filters(self) -> dict[str, Any]: - """Build search filters from initialization parameters. - - Returns: - Filter dictionary for mem0 v2 search API containing initialization parameters. - In the v2 API, filters holds the user_id, agent_id, run_id (thread_id), and app_id - (application_id) which are required for scoping memory search operations. 
- """ - filters: dict[str, Any] = {} - - if self.user_id: - filters["user_id"] = self.user_id - if self.agent_id: - filters["agent_id"] = self.agent_id - if self.scope_to_per_operation_thread_id and self._per_operation_thread_id: - filters["run_id"] = self._per_operation_thread_id - elif self.thread_id: - filters["run_id"] = self.thread_id - if self.application_id: - filters["app_id"] = self.application_id - - return filters - - def _validate_per_operation_thread_id(self, thread_id: str | None) -> None: - """Validates that a new thread ID doesn't conflict with an existing one when scoped. - - Args: - thread_id: The new thread ID or None. - - Raises: - ValueError: If a new thread ID is provided when one already exists. - """ - if ( - self.scope_to_per_operation_thread_id - and thread_id - and self._per_operation_thread_id - and thread_id != self._per_operation_thread_id - ): - raise ValueError( - "Mem0Provider can only be used with one thread at a time when scope_to_per_operation_thread_id is True." - ) diff --git a/python/packages/orchestrations/tests/test_handoff.py b/python/packages/orchestrations/tests/test_handoff.py index fe43fe4387..c947b46524 100644 --- a/python/packages/orchestrations/tests/test_handoff.py +++ b/python/packages/orchestrations/tests/test_handoff.py @@ -10,7 +10,6 @@ ChatResponse, ChatResponseUpdate, Content, - Context, BaseContextProvider, Message, ResponseStream, diff --git a/python/packages/redis/agent_framework_redis/_chat_message_store.py b/python/packages/redis/agent_framework_redis/_chat_message_store.py deleted file mode 100644 index 5ace6c13af..0000000000 --- a/python/packages/redis/agent_framework_redis/_chat_message_store.py +++ /dev/null @@ -1,595 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. 
- -from __future__ import annotations - -from collections.abc import Sequence -from typing import Any -from uuid import uuid4 - -import redis.asyncio as redis -from agent_framework import Message -from agent_framework._serialization import SerializationMixin -from redis.credentials import CredentialProvider - - -class RedisStoreState(SerializationMixin): - """State model for serializing and deserializing Redis chat message store data.""" - - def __init__( - self, - thread_id: str, - redis_url: str | None = None, - key_prefix: str = "chat_messages", - max_messages: int | None = None, - ) -> None: - """State model for serializing and deserializing Redis chat message store data.""" - self.thread_id = thread_id - self.redis_url = redis_url - self.key_prefix = key_prefix - self.max_messages = max_messages - - -class RedisChatMessageStore: - """Redis-backed implementation of ChatMessageStoreProtocol using Redis Lists. - - This implementation provides persistent, thread-safe chat message storage using Redis Lists. - Messages are stored as JSON-serialized strings in chronological order, with each conversation - thread isolated by a unique Redis key. 
- - Key Features: - ============ - - **Persistent Storage**: Messages survive application restarts and crashes - - **Thread Isolation**: Each conversation thread has its own Redis key namespace - - **Auto Message Limits**: Configurable automatic trimming of old messages using LTRIM - - **Performance Optimized**: Uses native Redis operations for efficiency - - **State Serialization**: Full compatibility with Agent Framework thread serialization - - **Initial Message Support**: Pre-load conversations with existing message history - - **Production Ready**: Atomic operations, error handling, connection pooling - - Redis Operations: - - RPUSH: Add messages to the end of the list (chronological order) - - LRANGE: Retrieve messages in chronological order - - LTRIM: Maintain message limits by trimming old messages - - DELETE: Clear all messages for a thread - """ - - def __init__( - self, - redis_url: str | None = None, - credential_provider: CredentialProvider | None = None, - host: str | None = None, - port: int = 6380, - ssl: bool = True, - username: str | None = None, - thread_id: str | None = None, - key_prefix: str = "chat_messages", - max_messages: int | None = None, - messages: Sequence[Message] | None = None, - ) -> None: - """Initialize the Redis chat message store. - - Creates a Redis-backed chat message store for a specific conversation thread. - Supports both traditional URL-based authentication and Azure Managed Redis - with credential provider. - - Args: - redis_url: Redis connection URL (e.g., "redis://localhost:6379"). - Used for traditional authentication. Mutually exclusive with credential_provider. - credential_provider: Redis credential provider (redis.credentials.CredentialProvider) for - Azure AD authentication. Requires host parameter. Mutually exclusive with redis_url. - host: Redis host name (e.g., "myredis.redis.cache.windows.net"). - Required when using credential_provider. - port: Redis port number. Defaults to 6380 (Azure Redis SSL port). 
- ssl: Enable SSL/TLS connection. Defaults to True. - username: Redis username. Defaults to None. - thread_id: Unique identifier for this conversation thread. - If not provided, a UUID will be auto-generated. - This becomes part of the Redis key: {key_prefix}:{thread_id} - key_prefix: Prefix for Redis keys to namespace different applications. - Defaults to 'chat_messages'. Useful for multi-tenant scenarios. - max_messages: Maximum number of messages to retain in Redis. - When exceeded, oldest messages are automatically trimmed using LTRIM. - None means unlimited storage. - messages: Initial messages to pre-populate the conversation. - These are added to Redis on first access if the Redis key is empty. - Useful for resuming conversations or seeding with context. - - Raises: - ValueError: If neither redis_url nor credential_provider is provided. - ValueError: If both redis_url and credential_provider are provided. - ValueError: If credential_provider is used without host parameter. - - Examples: - Traditional connection: - store = RedisChatMessageStore( - redis_url="redis://localhost:6379", - thread_id="conversation_123" - ) - - Azure Managed Redis with credential provider: - from redis.credentials import CredentialProvider - from azure.identity.aio import DefaultAzureCredential - - store = RedisChatMessageStore( - credential_provider=CredentialProvider(DefaultAzureCredential()), - host="myredis.redis.cache.windows.net", - thread_id="conversation_123" - ) - """ - # Validate connection parameters - if redis_url is None and credential_provider is None: - raise ValueError("Either redis_url or credential_provider must be provided") - - if redis_url is not None and credential_provider is not None: - raise ValueError("redis_url and credential_provider are mutually exclusive") - - if credential_provider is not None and host is None: - raise ValueError("host is required when using credential_provider") - - # Store configuration - self.thread_id = thread_id or 
f"thread_{uuid4()}" - self.key_prefix = key_prefix - self.max_messages = max_messages - - # Initialize Redis client based on authentication method - if credential_provider is not None and host is not None: - # Azure AD authentication with credential provider - self.redis_url = None # Not using URL-based auth - self._redis_client = redis.Redis( - host=host, - port=port, - ssl=ssl, - username=username, - credential_provider=credential_provider, - decode_responses=True, - ) - else: - # Traditional URL-based authentication - self.redis_url = redis_url - self._redis_client = redis.from_url(redis_url, decode_responses=True) # type: ignore[no-untyped-call] - - # Handle initial messages (will be moved to Redis on first access) - self._initial_messages = list(messages) if messages else [] - self._initial_messages_added = False - - @property - def redis_key(self) -> str: - """Get the Redis key for this thread's messages. - - The key format is: {key_prefix}:{thread_id} - - Returns: - Redis key string used for storing this thread's messages. - - Example: - For key_prefix="chat_messages" and thread_id="user_123_session_456": - Returns "chat_messages:user_123_session_456" - """ - return f"{self.key_prefix}:{self.thread_id}" - - async def _ensure_initial_messages_added(self) -> None: - """Ensure initial messages are added to Redis if not already present. - - This method is called before any Redis operations to guarantee that - initial messages provided during construction are persisted to Redis. 
- """ - if not self._initial_messages or self._initial_messages_added: - return - - # Check if Redis key already has messages (prevents duplicate additions) - existing_count = await self._redis_client.llen(self.redis_key) # type: ignore[misc] # type: ignore[misc] - if existing_count == 0: - # Add initial messages using atomic pipeline operation - await self._add_redis_messages(self._initial_messages) - - # Mark as completed and free memory - self._initial_messages_added = True - self._initial_messages.clear() - - async def _add_redis_messages(self, messages: Sequence[Message]) -> None: - """Add multiple messages to Redis using atomic pipeline operation. - - This internal method efficiently adds multiple messages to the Redis list - using a single atomic transaction to ensure consistency. - - Args: - messages: Sequence of Message objects to add to Redis. - """ - if not messages: - return - - # Pre-serialize all messages for efficient pipeline operation - serialized_messages = [self._serialize_message(message) for message in messages] - - # Use Redis pipeline for atomic batch operation - async with self._redis_client.pipeline(transaction=True) as pipe: - for serialized_message in serialized_messages: - await pipe.rpush(self.redis_key, serialized_message) # type: ignore[misc] - await pipe.execute() - - async def add_messages(self, messages: Sequence[Message]) -> None: - """Add messages to the Redis store (ChatMessageStoreProtocol protocol method). - - This method implements the required ChatMessageStoreProtocol protocol for adding messages. - Messages are appended to the Redis list in chronological order, with automatic - trimming if message limits are configured. - - Args: - messages: Sequence of Message objects to add to the store. - Can be empty (no-op) or contain multiple messages. - - Thread Safety: - - Atomic pipeline ensures all messages are added together - - LTRIM operation is atomic for consistent message limits - - Example: - .. 
code-block:: python - - messages = [Message(role="user", text="Hello"), Message(role="assistant", text="Hi there!")] - await store.add_messages(messages) - """ - if not messages: - return - - # Ensure any initial messages are persisted first - await self._ensure_initial_messages_added() - - # Add new messages using atomic pipeline operation - await self._add_redis_messages(messages) - - # Apply message limit if configured (automatic cleanup) - if self.max_messages is not None: - current_count = await self._redis_client.llen(self.redis_key) # type: ignore[misc] - if current_count > self.max_messages: - # Keep only the most recent max_messages using LTRIM - await self._redis_client.ltrim(self.redis_key, -self.max_messages, -1) # type: ignore[misc] - - async def list_messages(self) -> list[Message]: - """Get all messages from the store in chronological order (ChatMessageStoreProtocol protocol method). - - This method implements the required ChatMessageStoreProtocol protocol for retrieving messages. - Returns all messages stored in Redis, ordered from oldest (index 0) to newest (index -1). - - Returns: - List of Message objects in chronological order (oldest first). - Returns empty list if no messages exist or if Redis connection fails. - - Example: - .. 
code-block:: python - - # Get all conversation history - messages = await store.list_messages() - """ - # Ensure any initial messages are persisted to Redis first - await self._ensure_initial_messages_added() - - messages = [] - # Retrieve all messages from Redis list (oldest to newest) - redis_messages = await self._redis_client.lrange(self.redis_key, 0, -1) # type: ignore[misc] - - if redis_messages: - for serialized_message in redis_messages: - # Deserialize each JSON message back to Message - message = self._deserialize_message(serialized_message) - messages.append(message) - - return messages - - async def serialize(self, **kwargs: Any) -> Any: - """Serialize the current store state for persistence (ChatMessageStoreProtocol protocol method). - - This method implements the required ChatMessageStoreProtocol protocol for state serialization. - Captures the Redis connection configuration and thread information needed to - reconstruct the store and reconnect to the same conversation data. - - Keyword Args: - **kwargs: Additional arguments passed to Pydantic model_dump() for serialization. - Common options: exclude_none=True, by_alias=True - - Returns: - Dictionary containing serialized store configuration that can be persisted - to databases, files, or other storage mechanisms. - """ - state = RedisStoreState( - thread_id=self.thread_id, - redis_url=self.redis_url, - key_prefix=self.key_prefix, - max_messages=self.max_messages, - ) - return state.to_dict(exclude_none=False, **kwargs) - - @classmethod - async def deserialize(cls, serialized_store_state: Any, **kwargs: Any) -> RedisChatMessageStore: - """Deserialize state data into a new store instance (ChatMessageStoreProtocol protocol method). - - This method implements the required ChatMessageStoreProtocol protocol for state deserialization. - Creates a new RedisChatMessageStore instance from previously serialized data, - allowing the store to reconnect to the same conversation data in Redis. 
- - Args: - serialized_store_state: Previously serialized state data from serialize_state(). - Should be a dictionary with thread_id, redis_url, etc. - - Keyword Args: - **kwargs: Additional arguments passed to Pydantic model validation. - - Returns: - A new RedisChatMessageStore instance configured from the serialized state. - - Raises: - ValueError: If required fields are missing or invalid in the serialized state. - """ - if not serialized_store_state: - raise ValueError("serialized_store_state is required for deserialization") - - # Validate and parse the serialized state using Pydantic - state = RedisStoreState.from_dict(serialized_store_state, **kwargs) - - # Create and return a new store instance with the deserialized configuration - return cls( - redis_url=state.redis_url, - thread_id=state.thread_id, - key_prefix=state.key_prefix, - max_messages=state.max_messages, - ) - - async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) -> None: - """Deserialize state data into this store instance (ChatMessageStoreProtocol protocol method). - - This method implements the required ChatMessageStoreProtocol protocol for state deserialization. - Restores the store configuration from previously serialized data, allowing the store - to reconnect to the same conversation data in Redis. - - Args: - serialized_store_state: Previously serialized state data from serialize_state(). - Should be a dictionary with thread_id, redis_url, etc. - - Keyword Args: - **kwargs: Additional arguments passed to Pydantic model validation. 
- """ - if not serialized_store_state: - return - - # Validate and parse the serialized state using Pydantic - state = RedisStoreState.from_dict(serialized_store_state, **kwargs) - - # Update store configuration from deserialized state - self.thread_id = state.thread_id - if state.redis_url is not None: - self.redis_url = state.redis_url - self.key_prefix = state.key_prefix - self.max_messages = state.max_messages - - # Recreate Redis client if the URL changed - if state.redis_url and state.redis_url != getattr(self, "_last_redis_url", None): - self._redis_client = redis.from_url(state.redis_url, decode_responses=True) # type: ignore[no-untyped-call] - self._last_redis_url = state.redis_url - - # Reset initial message state since we're connecting to existing data - self._initial_messages_added = False - - async def clear(self) -> None: - """Remove all messages from the store. - - Permanently deletes all messages for this conversation thread by removing - the Redis key. This operation cannot be undone. - - Warning: - - This permanently deletes all conversation history - - Consider exporting messages before clearing if backup is needed - - Example: - .. code-block:: python - - # Clear conversation history - await store.clear() - - # Verify messages are gone - messages = await store.list_messages() - assert len(messages) == 0 - """ - await self._redis_client.delete(self.redis_key) - - def _serialize_message(self, message: Message) -> str: - """Serialize a Message to JSON string. - - Args: - message: Message to serialize. - - Returns: - JSON string representation of the message. - """ - # Serialize to compact JSON (no extra whitespace for Redis efficiency) - return message.to_json(separators=(",", ":")) - - def _deserialize_message(self, serialized_message: str) -> Message: - """Deserialize a JSON string to Message. - - Args: - serialized_message: JSON string representation of a message. - - Returns: - Message object. 
- """ - # Reconstruct Message using custom deserialization - return Message.from_json(serialized_message) - - # ============================================================================ - # List-like Convenience Methods (Redis-optimized async versions) - # ============================================================================ - - def __bool__(self) -> bool: - """Return True since the store always exists once created. - - This method is called by Python's truthiness checks (if store:). - Since a RedisChatMessageStore instance always represents a valid store, - this always returns True. - - Returns: - Always True - the store exists and is ready for operations. - - Note: - This is used by the Agent Framework to check if a message store - is configured: `if thread.message_store:` - """ - return True - - async def __len__(self) -> int: - """Return the number of messages in the Redis store. - - Provides efficient message counting using Redis LLEN command. - This is the async equivalent of Python's built-in len() function. - - Returns: - The count of messages currently stored in Redis. - """ - await self._ensure_initial_messages_added() - return await self._redis_client.llen(self.redis_key) # type: ignore[misc,no-any-return] - - async def getitem(self, index: int) -> Message: - """Get a message by index using Redis LINDEX. - - Args: - index: The index of the message to retrieve. - - Returns: - The Message at the specified index. - - Raises: - IndexError: If the index is out of range. - """ - await self._ensure_initial_messages_added() - - # Use Redis LINDEX for efficient single-item access - serialized_message = await self._redis_client.lindex(self.redis_key, index) # type: ignore[misc] - if serialized_message is None: - raise IndexError("list index out of range") - - return self._deserialize_message(serialized_message) - - async def setitem(self, index: int, item: Message) -> None: - """Set a message at the specified index using Redis LSET. 
- - Args: - index: The index at which to set the message. - item: The Message to set at the specified index. - - Raises: - IndexError: If the index is out of range. - """ - await self._ensure_initial_messages_added() - - # Validate index exists using LLEN - current_count = await self._redis_client.llen(self.redis_key) # type: ignore[misc] - if index < 0: - index = current_count + index - if index < 0 or index >= current_count: - raise IndexError("list index out of range") - - # Use Redis LSET for efficient single-item update - serialized_message = self._serialize_message(item) - await self._redis_client.lset(self.redis_key, index, serialized_message) # type: ignore[misc] - - async def append(self, item: Message) -> None: - """Append a message to the end of the store. - - Args: - item: The Message to append. - """ - await self.add_messages([item]) - - async def count(self) -> int: - """Return the number of messages in the Redis store. - - Returns: - The count of messages currently stored in Redis. - """ - await self._ensure_initial_messages_added() - return await self._redis_client.llen(self.redis_key) # type: ignore[misc,no-any-return] - - async def index(self, item: Message) -> int: - """Return the index of the first occurrence of the specified message. - - Uses Redis LINDEX to iterate through the list without loading all messages. - Still O(N) but more memory efficient for large lists. - - Args: - item: The Message to find. - - Returns: - The index of the first occurrence of the message. - - Raises: - ValueError: If the message is not found in the store. 
- """ - await self._ensure_initial_messages_added() - - target_serialized = self._serialize_message(item) - list_length = await self._redis_client.llen(self.redis_key) # type: ignore[misc] - - # Iterate through Redis list using LINDEX - for i in range(list_length): - redis_message = await self._redis_client.lindex(self.redis_key, i) # type: ignore[misc] - if redis_message == target_serialized: - return i - - raise ValueError("Message not found in store") - - async def remove(self, item: Message) -> None: - """Remove the first occurrence of the specified message from the store. - - Uses Redis LREM command for efficient removal by value. - O(N) but performed natively in Redis without data transfer. - - Args: - item: The Message to remove. - - Raises: - ValueError: If the message is not found in the store. - """ - await self._ensure_initial_messages_added() - - # Serialize the message to match Redis storage format - target_serialized = self._serialize_message(item) - - # Use LREM to remove first occurrence (count=1) - removed_count = await self._redis_client.lrem(self.redis_key, 1, target_serialized) # type: ignore[misc] - - if removed_count == 0: - raise ValueError("Message not found in store") - - async def extend(self, items: Sequence[Message]) -> None: - """Extend the store by appending all messages from the iterable. - - Args: - items: Sequence of Message objects to append. - """ - await self.add_messages(items) - - async def ping(self) -> bool: - """Test the Redis connection. - - Returns: - True if the connection is successful, False otherwise. - """ - try: - await self._redis_client.ping() # type: ignore[misc] - return True - except Exception: - return False - - async def aclose(self) -> None: - """Close the Redis connection. - - This method provides a clean way to close the underlying Redis connection - when the store is no longer needed. This is particularly useful in samples - and applications where explicit resource cleanup is desired. 
- """ - await self._redis_client.aclose() # type: ignore[misc] - - def __repr__(self) -> str: - """String representation of the store.""" - return ( - f"RedisChatMessageStore(thread_id='{self.thread_id}', " - f"redis_key='{self.redis_key}', max_messages={self.max_messages})" - ) diff --git a/python/packages/redis/agent_framework_redis/_provider.py b/python/packages/redis/agent_framework_redis/_provider.py deleted file mode 100644 index 193ea444d3..0000000000 --- a/python/packages/redis/agent_framework_redis/_provider.py +++ /dev/null @@ -1,595 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -from __future__ import annotations - -import json -import sys -from collections.abc import MutableSequence, Sequence -from functools import reduce -from operator import and_ -from typing import Any, Literal, cast - -import numpy as np -from agent_framework import Context, ContextProvider, Message -from agent_framework.exceptions import ( - AgentException, - ServiceInitializationError, - ServiceInvalidRequestError, -) -from redisvl.index import AsyncSearchIndex -from redisvl.query import FilterQuery, HybridQuery, TextQuery -from redisvl.query.filter import FilterExpression, Tag -from redisvl.utils.token_escaper import TokenEscaper -from redisvl.utils.vectorize import BaseVectorizer - -if sys.version_info >= (3, 11): - from typing import Self # pragma: no cover -else: - from typing_extensions import Self # pragma: no cover - -if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover -else: - from typing_extensions import override # type: ignore[import] # pragma: no cover - - -class RedisProvider(ContextProvider): - """Redis context provider with dynamic, filterable schema. - - Stores context in Redis and retrieves scoped context. - Uses full-text or optional hybrid vector search to ground model responses. 
- """ - - def __init__( - self, - redis_url: str = "redis://localhost:6379", - index_name: str = "context", - prefix: str = "context", - # Redis vectorizer configuration (optional, injected by client) - redis_vectorizer: BaseVectorizer | None = None, - vector_field_name: str | None = None, - vector_algorithm: Literal["flat", "hnsw"] | None = None, - vector_distance_metric: Literal["cosine", "ip", "l2"] | None = None, - # Partition fields (indexed for filtering) - application_id: str | None = None, - agent_id: str | None = None, - user_id: str | None = None, - thread_id: str | None = None, - scope_to_per_operation_thread_id: bool = False, - # Prompt and runtime - context_prompt: str = ContextProvider.DEFAULT_CONTEXT_PROMPT, - redis_index: Any = None, - overwrite_index: bool = False, - ): - """Create a Redis Context Provider. - - Args: - redis_url: The Redis server URL. - index_name: The name of the Redis index. - prefix: The prefix for all keys in the Redis database. - redis_vectorizer: The vectorizer to use for Redis. - vector_field_name: The name of the vector field in Redis. - vector_algorithm: The algorithm to use for vector search. - vector_distance_metric: The distance metric to use for vector search. - application_id: The application ID to scope the context. - agent_id: The agent ID to scope the context. - user_id: The user ID to scope the context. - thread_id: The thread ID to scope the context. - scope_to_per_operation_thread_id: Whether to scope to the per-operation thread ID. - context_prompt: The context prompt to use for the provider. - redis_index: The Redis index to use for the provider. - overwrite_index: Whether to overwrite the existing Redis index. 
- - """ - self.redis_url = redis_url - self.index_name = index_name - self.prefix = prefix - if redis_vectorizer is not None and not isinstance(redis_vectorizer, BaseVectorizer): - raise AgentException( - f"The redis vectorizer is not a valid type, got: {type(redis_vectorizer)}, expected: BaseVectorizer." - ) - self.redis_vectorizer = redis_vectorizer - self.vector_field_name = vector_field_name - self.vector_algorithm: Literal["flat", "hnsw"] | None = vector_algorithm - self.vector_distance_metric: Literal["cosine", "ip", "l2"] | None = vector_distance_metric - self.application_id = application_id - self.agent_id = agent_id - self.user_id = user_id - self.thread_id = thread_id - self.scope_to_per_operation_thread_id = scope_to_per_operation_thread_id - self.context_prompt = context_prompt - self.overwrite_index = overwrite_index - self._per_operation_thread_id: str | None = None - self._token_escaper: TokenEscaper = TokenEscaper() - self._conversation_id: str | None = None - self._index_initialized: bool = False - self._schema_dict: dict[str, Any] | None = None - self.redis_index = redis_index or AsyncSearchIndex.from_dict( - self.schema_dict, redis_url=self.redis_url, validate_on_load=True - ) - - @property - def schema_dict(self) -> dict[str, Any]: - """Get the Redis schema dictionary, computing and caching it on first access.""" - if self._schema_dict is None: - # Get vector configuration from vectorizer if available - vector_dims = self.redis_vectorizer.dims if self.redis_vectorizer is not None else None - vector_datatype = self.redis_vectorizer.dtype if self.redis_vectorizer is not None else None - - self._schema_dict = self._build_schema_dict( - index_name=self.index_name, - prefix=self.prefix, - vector_field_name=self.vector_field_name, - vector_dims=vector_dims, - vector_datatype=vector_datatype, - vector_algorithm=self.vector_algorithm, - vector_distance_metric=self.vector_distance_metric, - ) - return self._schema_dict - - def 
_build_filter_from_dict(self, filters: dict[str, str | None]) -> Any | None: - """Builds a combined filter expression from simple equality tags. - - This ANDs non-empty tag filters and is used to scope all operations to app/agent/user/thread partitions. - - Args: - filters: Mapping of field name to value; falsy values are ignored. - - Returns: - A combined filter expression or None if no filters are provided. - """ - parts = [Tag(k) == v for k, v in filters.items() if v] - return reduce(and_, parts) if parts else None - - def _build_schema_dict( - self, - *, - index_name: str, - prefix: str, - vector_field_name: str | None, - vector_dims: int | None, - vector_datatype: str | None, - vector_algorithm: Literal["flat", "hnsw"] | None, - vector_distance_metric: Literal["cosine", "ip", "l2"] | None, - ) -> dict[str, Any]: - """Builds the RediSearch schema configuration dictionary. - - Defines text and tag fields for messages plus an optional vector field enabling KNN/hybrid search. - - Keyword Args: - index_name: Index name. - prefix: Key prefix. - vector_field_name: Vector field name or None. - vector_dims: Vector dimensionality or None. - vector_datatype: Vector datatype or None. - vector_algorithm: Vector index algorithm or None. - vector_distance_metric: Vector distance metric or None. - - Returns: - Dict representing the index and fields configuration. 
- """ - fields: list[dict[str, Any]] = [ - {"name": "role", "type": "tag"}, - {"name": "mime_type", "type": "tag"}, - {"name": "content", "type": "text"}, - # Conversation tracking - {"name": "conversation_id", "type": "tag"}, - {"name": "message_id", "type": "tag"}, - {"name": "author_name", "type": "tag"}, - # Partition fields (TAG for fast filtering) - {"name": "application_id", "type": "tag"}, - {"name": "agent_id", "type": "tag"}, - {"name": "user_id", "type": "tag"}, - {"name": "thread_id", "type": "tag"}, - ] - - # Add vector field only if configured (keeps provider runnable with no params) - if vector_field_name is not None and vector_dims is not None: - fields.append({ - "name": vector_field_name, - "type": "vector", - "attrs": { - "algorithm": (vector_algorithm or "hnsw"), - "dims": int(vector_dims), - "distance_metric": (vector_distance_metric or "cosine"), - "datatype": (vector_datatype or "float32"), - }, - }) - - return { - "index": { - "name": index_name, - "prefix": prefix, - "key_separator": ":", - "storage_type": "hash", - }, - "fields": fields, - } - - async def _ensure_index(self) -> None: - """Initialize the search index. - - - Connect to existing index if it exists and schema matches - - Create new index if it doesn't exist - - Overwrite if requested via overwrite_index=True - - Validate schema compatibility to prevent accidental data loss - """ - if self._index_initialized: - return - - # Check if index already exists - index_exists = await self.redis_index.exists() - - if not self.overwrite_index and index_exists: - # Validate schema compatibility before connecting - await self._validate_schema_compatibility() - - # Create the index (will connect to existing or create new) - await self.redis_index.create(overwrite=self.overwrite_index, drop=False) - - self._index_initialized = True - - async def _validate_schema_compatibility(self) -> None: - """Validate that existing index schema matches current configuration. 
- - Raises ServiceInitializationError if schemas don't match, with helpful guidance. - - self._build_schema_dict returns a minimal schema while Redis returns an expanded - schema with all defaults filled in. To compare for incompatibilities, compare - significant parts of the schema by creating signatures with normalized default values. - """ - # Defaults for attr normalization - TAG_DEFAULTS = {"separator": ",", "case_sensitive": False, "withsuffixtrie": False} - TEXT_DEFAULTS = {"weight": 1.0, "no_stem": False} - - def _significant_index(i: dict[str, Any]) -> dict[str, Any]: - return {k: i.get(k) for k in ("name", "prefix", "key_separator", "storage_type")} - - def _sig_tag(attrs: dict[str, Any] | None) -> dict[str, Any]: - a = {**TAG_DEFAULTS, **(attrs or {})} - return {k: a[k] for k in ("separator", "case_sensitive", "withsuffixtrie")} - - def _sig_text(attrs: dict[str, Any] | None) -> dict[str, Any]: - a = {**TEXT_DEFAULTS, **(attrs or {})} - return {k: a[k] for k in ("weight", "no_stem")} - - def _sig_vector(attrs: dict[str, Any] | None) -> dict[str, Any]: - a = {**(attrs or {})} - # Require these to exist if vector field is present - return {k: a.get(k) for k in ("algorithm", "dims", "distance_metric", "datatype")} - - def _schema_signature(schema: dict[str, Any]) -> dict[str, Any]: - # Order-independent, minimal signature - sig: dict[str, Any] = {"index": _significant_index(schema.get("index", {})), "fields": {}} - for f in schema.get("fields", []): - name, ftype = f.get("name"), f.get("type") - if not name: - continue - if ftype == "tag": - sig["fields"][name] = {"type": "tag", "attrs": _sig_tag(f.get("attrs"))} - elif ftype == "text": - sig["fields"][name] = {"type": "text", "attrs": _sig_text(f.get("attrs"))} - elif ftype == "vector": - sig["fields"][name] = {"type": "vector", "attrs": _sig_vector(f.get("attrs"))} - else: - # Unknown field types: compare by type only - sig["fields"][name] = {"type": ftype} - return sig - - existing_index = await 
AsyncSearchIndex.from_existing(self.index_name, redis_url=self.redis_url) - existing_schema = existing_index.schema.to_dict() - current_schema = self.schema_dict - - existing_sig = _schema_signature(existing_schema) - current_sig = _schema_signature(current_schema) - - if existing_sig != current_sig: - # Add sigs to error message - raise ServiceInitializationError( - "Existing Redis index schema is incompatible with the current configuration.\n" - f"Existing (significant): {json.dumps(existing_sig, indent=2, sort_keys=True)}\n" - f"Current (significant): {json.dumps(current_sig, indent=2, sort_keys=True)}\n" - "Set overwrite_index=True to rebuild if this change is intentional." - ) - - async def _add( - self, - *, - data: dict[str, Any] | list[dict[str, Any]], - metadata: dict[str, Any] | None = None, - ) -> None: - """Inserts one or many documents with partition fields populated. - - Fills default partition fields, optionally embeds content when configured, and loads documents in a batch. - - Keyword Args: - data: Single document or list of documents to insert. - metadata: Optional metadata dictionary (unused placeholder). - - Raises: - ServiceInvalidRequestError: If required fields are missing or invalid. 
- """ - # Ensure provider has at least one scope set (symmetry with Mem0Provider) - self._validate_filters() - await self._ensure_index() - docs = data if isinstance(data, list) else [data] - - prepared: list[dict[str, Any]] = [] - for doc in docs: - d = dict(doc) # shallow copy - - # Partition defaults - d.setdefault("application_id", self.application_id) - d.setdefault("agent_id", self.agent_id) - d.setdefault("user_id", self.user_id) - d.setdefault("thread_id", self._effective_thread_id) - # Conversation defaults - d.setdefault("conversation_id", self._conversation_id) - - # Logical requirement - if "content" not in d: - raise ServiceInvalidRequestError("add() requires a 'content' field in data") - - # Vector field requirement (only if schema has one) - if self.vector_field_name: - d.setdefault(self.vector_field_name, None) - - prepared.append(d) - - # Batch embed contents for every message - if self.redis_vectorizer and self.vector_field_name: - text_list = [d["content"] for d in prepared] - embeddings = await self.redis_vectorizer.aembed_many(text_list, batch_size=len(text_list)) - for i, d in enumerate(prepared): - vec = np.asarray(embeddings[i], dtype=np.float32).tobytes() - field_name: str = self.vector_field_name - d[field_name] = vec - - # Load all at once if supported - await self.redis_index.load(prepared) - return - - async def _redis_search( - self, - text: str, - *, - text_scorer: str = "BM25STD", - filter_expression: Any | None = None, - return_fields: list[str] | None = None, - num_results: int = 10, - alpha: float = 0.7, - ) -> list[dict[str, Any]]: - """Runs a text or hybrid vector-text search with optional filters. - - Builds a TextQuery or HybridQuery and automatically ANDs partition filters to keep results scoped and safe. - - Args: - text: Query text. - - Keyword Args: - text_scorer: Scorer to use for text ranking. - filter_expression: Additional filter expression to AND with partition filters. - return_fields: Fields to return in results. 
- num_results: Maximum number of results. - alpha: Hybrid balancing parameter when vectors are enabled. - - Returns: - List of result dictionaries. - - Raises: - ServiceInvalidRequestError: If input is invalid or the query fails. - """ - # Enforce presence of at least one provider-level filter (symmetry with Mem0Provider) - await self._ensure_index() - self._validate_filters() - - q = (text or "").strip() - if not q: - raise ServiceInvalidRequestError("text_search() requires non-empty text") - num_results = max(int(num_results or 10), 1) - - combined_filter = self._build_filter_from_dict({ - "application_id": self.application_id, - "agent_id": self.agent_id, - "user_id": self.user_id, - "thread_id": self._effective_thread_id, - "conversation_id": self._conversation_id, - }) - - if filter_expression is not None: - combined_filter = (combined_filter & filter_expression) if combined_filter else filter_expression - - # Choose return fields - return_fields = ( - return_fields - if return_fields is not None - else ["content", "role", "application_id", "agent_id", "user_id", "thread_id"] - ) - - try: - if self.redis_vectorizer and self.vector_field_name: - # Build hybrid query: combine full-text and vector similarity - vector = await self.redis_vectorizer.aembed(q) - query = HybridQuery( - text=q, - text_field_name="content", - vector=vector, - vector_field_name=self.vector_field_name, - text_scorer=text_scorer, - filter_expression=combined_filter, - alpha=alpha, - dtype=self.redis_vectorizer.dtype, - num_results=num_results, - return_fields=return_fields, - stopwords=None, - ) - hybrid_results = await self.redis_index.query(query) - return cast(list[dict[str, Any]], hybrid_results) - # Text-only search - query = TextQuery( - text=q, - text_field_name="content", - text_scorer=text_scorer, - filter_expression=combined_filter, - num_results=num_results, - return_fields=return_fields, - stopwords=None, - ) - text_results = await self.redis_index.query(query) - return 
cast(list[dict[str, Any]], text_results) - except Exception as exc: # pragma: no cover - surface as framework error - raise ServiceInvalidRequestError(f"Redis text search failed: {exc}") from exc - - async def search_all(self, page_size: int = 200) -> list[dict[str, Any]]: - """Returns all documents in the index. - - Streams results via pagination to avoid excessive memory and response sizes. - - Args: - page_size: Page size used for pagination under the hood. - - Returns: - List of all documents. - """ - out: list[dict[str, Any]] = [] - async for batch in self.redis_index.paginate( - FilterQuery(FilterExpression("*"), return_fields=[], num_results=page_size), - page_size=page_size, - ): - out.extend(batch) - return out - - @property - def _effective_thread_id(self) -> str | None: - """Resolves the active thread id. - - Returns per-operation thread id when scoping is enabled; otherwise the provider's thread id. - """ - return self._per_operation_thread_id if self.scope_to_per_operation_thread_id else self.thread_id - - @override - async def thread_created(self, thread_id: str | None) -> None: - """Called when a new thread is created. - - Captures the per-operation thread id when scoping is enabled to enforce single-thread usage. - - Args: - thread_id: The ID of the thread or None. 
- """ - self._validate_per_operation_thread_id(thread_id) - self._per_operation_thread_id = self._per_operation_thread_id or thread_id - # Track current conversation id (Agent passes conversation_id here) - self._conversation_id = thread_id or self._conversation_id - - @override - async def invoked( - self, - request_messages: Message | Sequence[Message], - response_messages: Message | Sequence[Message] | None = None, - invoke_exception: Exception | None = None, - **kwargs: Any, - ) -> None: - self._validate_filters() - - request_messages_list = [request_messages] if isinstance(request_messages, Message) else list(request_messages) - response_messages_list = ( - [response_messages] - if isinstance(response_messages, Message) - else list(response_messages) - if response_messages - else [] - ) - messages_list = [*request_messages_list, *response_messages_list] - - messages: list[dict[str, Any]] = [] - for message in messages_list: - if message.role in {"user", "assistant", "system"} and message.text and message.text.strip(): - shaped: dict[str, Any] = { - "role": message.role, - "content": message.text, - "conversation_id": self._conversation_id, - "message_id": message.message_id, - "author_name": message.author_name, - } - messages.append(shaped) - if messages: - await self._add(data=messages) - - @override - async def invoking(self, messages: Message | MutableSequence[Message], **kwargs: Any) -> Context: - """Called before invoking the model to provide scoped context. - - Concatenates recent messages into a query, fetches matching memories from Redis. - Prepends them as instructions. - - Args: - messages: List of new messages in the thread. - - Keyword Args: - **kwargs: not used at present at present. - - Returns: - Context: Context object containing instructions with memories. 
- """ - self._validate_filters() - messages_list = [messages] if isinstance(messages, Message) else list(messages) - input_text = "\n".join(msg.text for msg in messages_list if msg and msg.text and msg.text.strip()) - - memories = await self._redis_search(text=input_text) - line_separated_memories = "\n".join( - str(memory.get("content", "")) for memory in memories if memory.get("content") - ) - - return Context( - messages=[Message(role="user", text=f"{self.context_prompt}\n{line_separated_memories}")] - if line_separated_memories - else None - ) - - async def __aenter__(self) -> Self: - """Async context manager entry. - - No special setup is required; provided for symmetry with the Mem0 provider. - """ - return self - - async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: - """Async context manager exit. - - No cleanup is required; indexes/keys remain unless explicitly cleared. - """ - return - - def _validate_filters(self) -> None: - """Validates that at least one filter is provided. - - Prevents unbounded operations by requiring a partition filter before reads or writes. - - Raises: - ServiceInitializationError: If no filters are provided. - """ - if not self.agent_id and not self.user_id and not self.application_id and not self.thread_id: - raise ServiceInitializationError( - "At least one of the filters: agent_id, user_id, application_id, or thread_id is required." - ) - - def _validate_per_operation_thread_id(self, thread_id: str | None) -> None: - """Validates that a new thread ID doesn't conflict when scoped. - - Prevents cross-thread data leakage by enforcing single-thread usage when per-operation scoping is enabled. - - Args: - thread_id: The new thread ID or None. - - Raises: - ValueError: If a new thread ID conflicts with the existing one. 
- """ - if ( - self.scope_to_per_operation_thread_id - and thread_id - and self._per_operation_thread_id - and thread_id != self._per_operation_thread_id - ): - raise ValueError( - "RedisProvider can only be used with one thread, when scope_to_per_operation_thread_id is True." - ) From cf4cb173329377c0ed289d96aa83ddab6eedc8b2 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Feb 2026 13:44:52 +0100 Subject: [PATCH 08/28] rename: remove _new_ prefix from test files --- ...r.py => test_aisearch_context_provider.py} | 0 .../mem0/tests/test_mem0_context_provider.py | 712 ++++++------------ .../tests/test_mem0_new_context_provider.py | 352 --------- ...est_new_providers.py => test_providers.py} | 0 4 files changed, 251 insertions(+), 813 deletions(-) rename python/packages/azure-ai-search/tests/{test_aisearch_new_context_provider.py => test_aisearch_context_provider.py} (100%) delete mode 100644 python/packages/mem0/tests/test_mem0_new_context_provider.py rename python/packages/redis/tests/{test_new_providers.py => test_providers.py} (100%) diff --git a/python/packages/azure-ai-search/tests/test_aisearch_new_context_provider.py b/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py similarity index 100% rename from python/packages/azure-ai-search/tests/test_aisearch_new_context_provider.py rename to python/packages/azure-ai-search/tests/test_aisearch_context_provider.py diff --git a/python/packages/mem0/tests/test_mem0_context_provider.py b/python/packages/mem0/tests/test_mem0_context_provider.py index 129f1bfa61..96a70c2beb 100644 --- a/python/packages/mem0/tests/test_mem0_context_provider.py +++ b/python/packages/mem0/tests/test_mem0_context_provider.py @@ -1,22 +1,16 @@ # Copyright (c) Microsoft. All rights reserved. 
# pyright: reportPrivateUsage=false -import importlib -import os -import sys -from typing import Any -from unittest.mock import AsyncMock +from __future__ import annotations + +from unittest.mock import AsyncMock, patch import pytest from agent_framework import AgentResponse, Message from agent_framework._sessions import AgentSession, SessionContext from agent_framework.exceptions import ServiceInitializationError -from agent_framework.mem0 import Mem0ContextProvider - -def test_mem0_context_provider_import() -> None: - """Test that Mem0ContextProvider can be imported.""" - assert Mem0ContextProvider is not None +from agent_framework_mem0._context_provider import Mem0ContextProvider @pytest.fixture @@ -29,534 +23,330 @@ def mock_mem0_client() -> AsyncMock: mock_client.search = AsyncMock() mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock() - mock_client.async_client = AsyncMock() - mock_client.async_client.aclose = AsyncMock() return mock_client -@pytest.fixture -def mock_agent() -> AsyncMock: - """Create a mock agent.""" - return AsyncMock() - - -@pytest.fixture -def session() -> AgentSession: - """Create a test AgentSession.""" - return AgentSession(session_id="test-session") - - -@pytest.fixture -def sample_messages() -> list[Message]: - """Create sample chat messages for testing.""" - return [ - Message(role="user", text="Hello, how are you?"), - Message(role="assistant", text="I'm doing well, thank you!"), - Message(role="system", text="You are a helpful assistant"), - ] - - -def _make_context(input_messages: list[Message], session_id: str = "test-session") -> SessionContext: - """Helper to create a SessionContext with the given input messages.""" - return SessionContext(session_id=session_id, input_messages=input_messages) - - -def _empty_state() -> dict[str, Any]: - """Helper to create an empty state dict.""" - return {} - - -def test_init_with_all_ids(mock_mem0_client: AsyncMock) -> None: - """Test initialization 
with all IDs provided.""" - provider = Mem0ContextProvider( - source_id="mem0", - user_id="user123", - agent_id="agent123", - application_id="app123", - mem0_client=mock_mem0_client, - ) - assert provider.user_id == "user123" - assert provider.agent_id == "agent123" - assert provider.application_id == "app123" - - -def test_init_without_filters_succeeds(mock_mem0_client: AsyncMock) -> None: - """Test that initialization succeeds even without filters (validation happens during invocation).""" - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) - assert provider.user_id is None - assert provider.agent_id is None - assert provider.application_id is None - - -def test_init_with_custom_context_prompt(mock_mem0_client: AsyncMock) -> None: - """Test initialization with custom context prompt.""" - custom_prompt = "## Custom Memories\nConsider these memories:" - provider = Mem0ContextProvider( - source_id="mem0", user_id="user123", context_prompt=custom_prompt, mem0_client=mock_mem0_client - ) - assert provider.context_prompt == custom_prompt +# -- Initialization tests ------------------------------------------------------ -def test_init_with_provided_client_should_not_close(mock_mem0_client: AsyncMock) -> None: - """Test that provided client should not be closed by provider.""" - provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) - assert provider._should_close_client is False +class TestInit: + """Test Mem0ContextProvider initialization.""" - -async def test_async_context_manager_entry(mock_mem0_client: AsyncMock) -> None: - """Test async context manager entry returns self.""" - provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) - async with provider as ctx: - assert ctx is provider - - -async def test_async_context_manager_exit_does_not_close_provided_client(mock_mem0_client: AsyncMock) -> None: - """Test that async context manager does not close 
provided client.""" - provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) - assert provider._should_close_client is False - - async with provider: - pass - - mock_mem0_client.__aexit__.assert_not_called() - - -class TestMem0ContextProviderAfterRun: - """Test after_run method (storing messages to Mem0).""" - - async def test_after_run_fails_without_filters( - self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession - ) -> None: - """Test that after_run fails when no filters are provided.""" - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) - ctx = _make_context([Message(role="user", text="Hello!")]) - - with pytest.raises(ServiceInitializationError) as exc_info: - await provider.after_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) - - assert "At least one of the filters" in str(exc_info.value) - - async def test_after_run_single_input_message( - self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession - ) -> None: - """Test storing a single input message.""" - provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) - ctx = _make_context([Message(role="user", text="Hello!")]) - - await provider.after_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) - - mock_mem0_client.add.assert_called_once() - call_args = mock_mem0_client.add.call_args - assert call_args.kwargs["messages"] == [{"role": "user", "content": "Hello!"}] - assert call_args.kwargs["user_id"] == "user123" - - async def test_after_run_multiple_messages( - self, - mock_mem0_client: AsyncMock, - mock_agent: AsyncMock, - session: AgentSession, - sample_messages: list[Message], - ) -> None: - """Test storing multiple input messages.""" - provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) - ctx = _make_context(sample_messages) - - await 
provider.after_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) - - mock_mem0_client.add.assert_called_once() - call_args = mock_mem0_client.add.call_args - expected_messages = [ - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I'm doing well, thank you!"}, - {"role": "system", "content": "You are a helpful assistant"}, - ] - assert call_args.kwargs["messages"] == expected_messages - - async def test_after_run_includes_response_messages( - self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession - ) -> None: - """Test that after_run includes response messages.""" - provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) - ctx = _make_context([Message(role="user", text="Hello!")]) - ctx._response = AgentResponse(messages=[Message(role="assistant", text="Hi there!")]) - - await provider.after_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) - - mock_mem0_client.add.assert_called_once() - call_args = mock_mem0_client.add.call_args - expected_messages = [ - {"role": "user", "content": "Hello!"}, - {"role": "assistant", "content": "Hi there!"}, - ] - assert call_args.kwargs["messages"] == expected_messages - - async def test_after_run_with_agent_id( - self, - mock_mem0_client: AsyncMock, - mock_agent: AsyncMock, - session: AgentSession, - sample_messages: list[Message], - ) -> None: - """Test storing messages with agent_id.""" - provider = Mem0ContextProvider(source_id="mem0", agent_id="agent123", mem0_client=mock_mem0_client) - ctx = _make_context(sample_messages) - - await provider.after_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) - - call_args = mock_mem0_client.add.call_args - assert call_args.kwargs["agent_id"] == "agent123" - assert call_args.kwargs["user_id"] is None - - async def test_after_run_with_application_id( - self, - mock_mem0_client: AsyncMock, - mock_agent: AsyncMock, - 
session: AgentSession, - sample_messages: list[Message], - ) -> None: - """Test storing messages with application_id in metadata.""" + def test_init_with_all_params(self, mock_mem0_client: AsyncMock) -> None: provider = Mem0ContextProvider( - source_id="mem0", user_id="user123", application_id="app123", mem0_client=mock_mem0_client + source_id="mem0", + mem0_client=mock_mem0_client, + api_key="key-123", + application_id="app1", + agent_id="agent1", + user_id="user1", + context_prompt="Custom prompt", ) - ctx = _make_context(sample_messages) - - await provider.after_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) - - call_args = mock_mem0_client.add.call_args - assert call_args.kwargs["metadata"] == {"application_id": "app123"} - - async def test_after_run_uses_session_id_as_run_id( - self, - mock_mem0_client: AsyncMock, - mock_agent: AsyncMock, - session: AgentSession, - sample_messages: list[Message], - ) -> None: - """Test that after_run uses the context session_id as run_id.""" - provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) - ctx = _make_context(sample_messages, session_id="my-session") - - await provider.after_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) - - call_args = mock_mem0_client.add.call_args - assert call_args.kwargs["run_id"] == "my-session" - - async def test_after_run_filters_empty_messages( - self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession - ) -> None: - """Test that empty or invalid messages are filtered out.""" - provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) - messages = [ - Message(role="user", text=""), - Message(role="user", text=" "), - Message(role="user", text="Valid message"), + assert provider.source_id == "mem0" + assert provider.api_key == "key-123" + assert provider.application_id == "app1" + assert provider.agent_id == "agent1" + assert 
provider.user_id == "user1" + assert provider.context_prompt == "Custom prompt" + assert provider.mem0_client is mock_mem0_client + assert provider._should_close_client is False + + def test_init_default_context_prompt(self, mock_mem0_client: AsyncMock) -> None: + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + assert provider.context_prompt == Mem0ContextProvider.DEFAULT_CONTEXT_PROMPT + + def test_init_auto_creates_client_when_none(self) -> None: + """When no client is provided, a default AsyncMemoryClient is created and flagged for closing.""" + with ( + patch("mem0.client.main.AsyncMemoryClient.__init__", return_value=None) as mock_init, + patch("mem0.client.main.AsyncMemoryClient._validate_api_key", return_value=None), + ): + provider = Mem0ContextProvider(source_id="mem0", api_key="test-key", user_id="u1") + mock_init.assert_called_once_with(api_key="test-key") + assert provider._should_close_client is True + + def test_provided_client_not_flagged_for_close(self, mock_mem0_client: AsyncMock) -> None: + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + assert provider._should_close_client is False + + +# -- before_run tests ---------------------------------------------------------- + + +class TestBeforeRun: + """Test before_run hook.""" + + async def test_memories_added_to_context(self, mock_mem0_client: AsyncMock) -> None: + """Mocked mem0 search returns memories → messages added to context with prompt.""" + mock_mem0_client.search.return_value = [ + {"memory": "User likes Python"}, + {"memory": "User prefers dark mode"}, ] - ctx = _make_context(messages) + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[Message(role="user", text="Hello")], session_id="s1") - await provider.after_run(agent=mock_agent, session=session, context=ctx, 
state=_empty_state()) + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - call_args = mock_mem0_client.add.call_args - assert call_args.kwargs["messages"] == [{"role": "user", "content": "Valid message"}] + mock_mem0_client.search.assert_awaited_once() + assert "mem0" in ctx.context_messages + added = ctx.context_messages["mem0"] + assert len(added) == 1 + assert "User likes Python" in added[0].text # type: ignore[operator] + assert "User prefers dark mode" in added[0].text # type: ignore[operator] + assert provider.context_prompt in added[0].text # type: ignore[operator] - async def test_after_run_skips_when_no_valid_messages( - self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession - ) -> None: - """Test that mem0 client is not called when no valid messages exist.""" - provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) - messages = [ - Message(role="user", text=""), - Message(role="user", text=" "), - ] - ctx = _make_context(messages) + async def test_empty_input_skips_search(self, mock_mem0_client: AsyncMock) -> None: + """Empty input messages → no search performed.""" + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[Message(role="user", text="")], session_id="s1") - await provider.after_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - mock_mem0_client.add.assert_not_called() + mock_mem0_client.search.assert_not_awaited() + assert "mem0" not in ctx.context_messages + async def test_empty_search_results_no_messages(self, mock_mem0_client: AsyncMock) -> None: + """Empty search results → no messages added.""" + mock_mem0_client.search.return_value = [] + 
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[Message(role="user", text="test")], session_id="s1") -class TestMem0ContextProviderBeforeRun: - """Test before_run method (searching memories and adding to context).""" - - async def test_before_run_fails_without_filters( - self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession - ) -> None: - """Test that before_run fails when no filters are provided.""" - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) - ctx = _make_context([Message(role="user", text="What's the weather?")]) - - with pytest.raises(ServiceInitializationError) as exc_info: - await provider.before_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - assert "At least one of the filters" in str(exc_info.value) + assert "mem0" not in ctx.context_messages - async def test_before_run_single_message( - self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession - ) -> None: - """Test before_run with a single input message.""" - provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) - ctx = _make_context([Message(role="user", text="What's the weather?")]) + async def test_validates_filters_before_search(self, mock_mem0_client: AsyncMock) -> None: + """Raises ServiceInitializationError when no filters.""" + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[Message(role="user", text="test")], session_id="s1") - mock_mem0_client.search.return_value = [ - {"memory": "User likes outdoor activities"}, - {"memory": "User lives in Seattle"}, - ] + with 
pytest.raises(ServiceInitializationError, match="At least one of the filters"): + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - await provider.before_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) + async def test_v1_1_response_format(self, mock_mem0_client: AsyncMock) -> None: + """Search response in v1.1 dict format with 'results' key.""" + mock_mem0_client.search.return_value = {"results": [{"memory": "remembered fact"}]} + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[Message(role="user", text="test")], session_id="s1") - mock_mem0_client.search.assert_called_once() - call_args = mock_mem0_client.search.call_args - assert call_args.kwargs["query"] == "What's the weather?" - assert call_args.kwargs["filters"] == {"user_id": "user123", "run_id": "test-session"} + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - context_messages = ctx.get_messages() - assert len(context_messages) > 0 - expected_text = ( - "## Memories\nConsider the following memories when answering user questions:\n" - "User likes outdoor activities\nUser lives in Seattle" - ) - assert context_messages[0].text == expected_text - - async def test_before_run_multiple_messages( - self, - mock_mem0_client: AsyncMock, - mock_agent: AsyncMock, - session: AgentSession, - sample_messages: list[Message], - ) -> None: - """Test before_run with multiple input messages.""" - provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) - ctx = _make_context(sample_messages) - - mock_mem0_client.search.return_value = [{"memory": "Previous conversation context"}] - - await provider.before_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) - - call_args = 
mock_mem0_client.search.call_args - expected_query = "Hello, how are you?\nI'm doing well, thank you!\nYou are a helpful assistant" - assert call_args.kwargs["query"] == expected_query - - async def test_before_run_with_agent_id( - self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession - ) -> None: - """Test before_run with agent_id.""" - provider = Mem0ContextProvider(source_id="mem0", agent_id="agent123", mem0_client=mock_mem0_client) - ctx = _make_context([Message(role="user", text="Hello")]) + added = ctx.context_messages["mem0"] + assert "remembered fact" in added[0].text # type: ignore[operator] + async def test_search_query_combines_input_messages(self, mock_mem0_client: AsyncMock) -> None: + """Multiple input messages are joined for the search query.""" mock_mem0_client.search.return_value = [] + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext( + input_messages=[ + Message(role="user", text="Hello"), + Message(role="user", text="World"), + ], + session_id="s1", + ) - await provider.before_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) - - call_args = mock_mem0_client.search.call_args - assert call_args.kwargs["filters"] == {"agent_id": "agent123", "run_id": "test-session"} + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - async def test_before_run_with_session_id_in_filters( - self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession - ) -> None: - """Test before_run includes session_id as run_id in search filters.""" - provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) - ctx = _make_context([Message(role="user", text="Hello")], session_id="my-session") + call_kwargs = mock_mem0_client.search.call_args.kwargs + assert call_kwargs["query"] == 
"Hello\nWorld" - mock_mem0_client.search.return_value = [] - await provider.before_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) +# -- after_run tests ----------------------------------------------------------- - call_args = mock_mem0_client.search.call_args - assert call_args.kwargs["filters"] == {"user_id": "user123", "run_id": "my-session"} - async def test_before_run_no_memories_does_not_add_messages( - self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession - ) -> None: - """Test that no memories does not add context messages.""" - provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) - ctx = _make_context([Message(role="user", text="Hello")]) +class TestAfterRun: + """Test after_run hook.""" - mock_mem0_client.search.return_value = [] + async def test_stores_input_and_response(self, mock_mem0_client: AsyncMock) -> None: + """Stores input+response messages to mem0 via client.add.""" + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[Message(role="user", text="question")], session_id="s1") + ctx._response = AgentResponse(messages=[Message(role="assistant", text="answer")]) - await provider.before_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) + await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - context_messages = ctx.get_messages() - assert len(context_messages) == 0 + mock_mem0_client.add.assert_awaited_once() + call_kwargs = mock_mem0_client.add.call_args.kwargs + assert call_kwargs["messages"] == [ + {"role": "user", "content": "question"}, + {"role": "assistant", "content": "answer"}, + ] + assert call_kwargs["user_id"] == "u1" + assert call_kwargs["run_id"] == "s1" + + async def test_only_stores_user_assistant_system(self, mock_mem0_client: 
AsyncMock) -> None: + """Only stores user/assistant/system messages with text.""" + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext( + input_messages=[ + Message(role="user", text="hello"), + Message(role="tool", text="tool output"), + ], + session_id="s1", + ) + ctx._response = AgentResponse(messages=[Message(role="assistant", text="reply")]) + + await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + call_kwargs = mock_mem0_client.add.call_args.kwargs + roles = [m["role"] for m in call_kwargs["messages"]] + assert "tool" not in roles + assert roles == ["user", "assistant"] + + async def test_skips_empty_messages(self, mock_mem0_client: AsyncMock) -> None: + """Skips messages with empty text.""" + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext( + input_messages=[ + Message(role="user", text=""), + Message(role="user", text=" "), + ], + session_id="s1", + ) + ctx._response = AgentResponse(messages=[]) - async def test_before_run_empty_input_text_skips_search( - self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession - ) -> None: - """Test that empty input text skips the search entirely.""" - provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) - ctx = _make_context([Message(role="user", text=""), Message(role="user", text=" ")]) + await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - await provider.before_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) + mock_mem0_client.add.assert_not_awaited() - mock_mem0_client.search.assert_not_called() + async def test_uses_session_id_as_run_id(self, mock_mem0_client: 
AsyncMock) -> None: + """Uses session_id as run_id.""" + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[Message(role="user", text="hi")], session_id="my-session") + ctx._response = AgentResponse(messages=[Message(role="assistant", text="hey")]) - async def test_before_run_filters_empty_message_text( - self, mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession - ) -> None: - """Test that empty message text is filtered out from query.""" - provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) - messages = [ - Message(role="user", text=""), - Message(role="user", text="Valid message"), - Message(role="user", text=" "), - ] - ctx = _make_context(messages) + await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - mock_mem0_client.search.return_value = [] + assert mock_mem0_client.add.call_args.kwargs["run_id"] == "my-session" - await provider.before_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) + async def test_validates_filters(self, mock_mem0_client: AsyncMock) -> None: + """Raises ServiceInitializationError when no filters.""" + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[Message(role="user", text="hi")], session_id="s1") + ctx._response = AgentResponse(messages=[Message(role="assistant", text="hey")]) - call_args = mock_mem0_client.search.call_args - assert call_args.kwargs["query"] == "Valid message" + with pytest.raises(ServiceInitializationError, match="At least one of the filters"): + await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - async def test_before_run_custom_context_prompt( - self, 
mock_mem0_client: AsyncMock, mock_agent: AsyncMock, session: AgentSession - ) -> None: - """Test before_run with custom context prompt.""" - custom_prompt = "## Custom Context\nRemember these details:" + async def test_stores_with_application_id_metadata(self, mock_mem0_client: AsyncMock) -> None: + """application_id is passed in metadata.""" provider = Mem0ContextProvider( - source_id="mem0", - user_id="user123", - context_prompt=custom_prompt, - mem0_client=mock_mem0_client, + source_id="mem0", mem0_client=mock_mem0_client, user_id="u1", application_id="app1" ) - ctx = _make_context([Message(role="user", text="Hello")]) + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[Message(role="user", text="hi")], session_id="s1") + ctx._response = AgentResponse(messages=[]) - mock_mem0_client.search.return_value = [{"memory": "Test memory"}] + await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - await provider.before_run(agent=mock_agent, session=session, context=ctx, state=_empty_state()) + assert mock_mem0_client.add.call_args.kwargs["metadata"] == {"application_id": "app1"} - context_messages = ctx.get_messages() - expected_text = "## Custom Context\nRemember these details:\nTest memory" - assert len(context_messages) > 0 - assert context_messages[0].text == expected_text +# -- _validate_filters tests -------------------------------------------------- -class TestMem0ContextProviderValidation: - """Test validation methods.""" - def test_validate_filters_fails_without_any_filter(self, mock_mem0_client: AsyncMock) -> None: - """Test validation failure when no filters are set.""" - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) +class TestValidateFilters: + """Test _validate_filters method.""" - with pytest.raises(ServiceInitializationError) as exc_info: + def test_raises_when_no_filters(self, mock_mem0_client: AsyncMock) -> None: + provider 
= Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) + with pytest.raises(ServiceInitializationError, match="At least one of the filters"): provider._validate_filters() - assert "At least one of the filters" in str(exc_info.value) + def test_passes_with_user_id(self, mock_mem0_client: AsyncMock) -> None: + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + provider._validate_filters() # should not raise - def test_validate_filters_succeeds_with_user_id(self, mock_mem0_client: AsyncMock) -> None: - """Test validation succeeds with user_id set.""" - provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) - provider._validate_filters() # Should not raise + def test_passes_with_agent_id(self, mock_mem0_client: AsyncMock) -> None: + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, agent_id="a1") + provider._validate_filters() - def test_validate_filters_succeeds_with_agent_id(self, mock_mem0_client: AsyncMock) -> None: - """Test validation succeeds with agent_id set.""" - provider = Mem0ContextProvider(source_id="mem0", agent_id="agent123", mem0_client=mock_mem0_client) - provider._validate_filters() # Should not raise + def test_passes_with_application_id(self, mock_mem0_client: AsyncMock) -> None: + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, application_id="app1") + provider._validate_filters() - def test_validate_filters_succeeds_with_application_id(self, mock_mem0_client: AsyncMock) -> None: - """Test validation succeeds with application_id set.""" - provider = Mem0ContextProvider(source_id="mem0", application_id="app123", mem0_client=mock_mem0_client) - provider._validate_filters() # Should not raise +# -- _build_filters tests ----------------------------------------------------- -class TestMem0ContextProviderBuildFilters: - """Test the _build_filters method.""" - def 
test_build_filters_with_user_id_only(self, mock_mem0_client: AsyncMock) -> None: - """Test building filters with only user_id.""" - provider = Mem0ContextProvider(source_id="mem0", user_id="user123", mem0_client=mock_mem0_client) +class TestBuildFilters: + """Test _build_filters method.""" - filters = provider._build_filters() - assert filters == {"user_id": "user123"} + def test_user_id_only(self, mock_mem0_client: AsyncMock) -> None: + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + assert provider._build_filters() == {"user_id": "u1"} - def test_build_filters_with_all_parameters(self, mock_mem0_client: AsyncMock) -> None: - """Test building filters with all initialization parameters.""" + def test_all_params(self, mock_mem0_client: AsyncMock) -> None: provider = Mem0ContextProvider( source_id="mem0", - user_id="user123", - agent_id="agent456", - application_id="app999", mem0_client=mock_mem0_client, + user_id="u1", + agent_id="a1", + application_id="app1", ) - - filters = provider._build_filters() - assert filters == { - "user_id": "user123", - "agent_id": "agent456", - "app_id": "app999", + assert provider._build_filters(session_id="sess1") == { + "user_id": "u1", + "agent_id": "a1", + "run_id": "sess1", + "app_id": "app1", } - def test_build_filters_excludes_none_values(self, mock_mem0_client: AsyncMock) -> None: - """Test that None values are excluded from filters.""" - provider = Mem0ContextProvider( - source_id="mem0", - user_id="user123", - agent_id=None, - application_id=None, - mem0_client=mock_mem0_client, - ) - + def test_excludes_none_values(self, mock_mem0_client: AsyncMock) -> None: + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") filters = provider._build_filters() - assert filters == {"user_id": "user123"} assert "agent_id" not in filters + assert "run_id" not in filters assert "app_id" not in filters - def test_build_filters_with_session_id(self, 
mock_mem0_client: AsyncMock) -> None: - """Test that session_id is included as run_id in filters.""" - provider = Mem0ContextProvider( - source_id="mem0", - user_id="user123", - mem0_client=mock_mem0_client, - ) - - filters = provider._build_filters(session_id="session-123") - assert filters == { - "user_id": "user123", - "run_id": "session-123", - } + def test_session_id_mapped_to_run_id(self, mock_mem0_client: AsyncMock) -> None: + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + filters = provider._build_filters(session_id="s99") + assert filters["run_id"] == "s99" - def test_build_filters_returns_empty_dict_when_no_parameters(self, mock_mem0_client: AsyncMock) -> None: - """Test that _build_filters returns an empty dict when no parameters are set.""" + def test_empty_when_no_params(self, mock_mem0_client: AsyncMock) -> None: provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) + assert provider._build_filters() == {} - filters = provider._build_filters() - assert filters == {} - - -class TestMem0Telemetry: - """Test telemetry configuration for Mem0.""" - - def test_mem0_telemetry_disabled_by_default(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Test that MEM0_TELEMETRY is set to 'false' by default when importing the package.""" - # Ensure MEM0_TELEMETRY is not set before importing the module under test - monkeypatch.delenv("MEM0_TELEMETRY", raising=False) - - # Remove cached modules to force re-import and trigger module-level initialization - modules_to_remove = [key for key in sys.modules if key.startswith("agent_framework_mem0")] - for mod in modules_to_remove: - del sys.modules[mod] - - # Import (and reload) the module so that it can set MEM0_TELEMETRY when unset - import agent_framework_mem0 - importlib.reload(agent_framework_mem0) +# -- Context manager tests ----------------------------------------------------- - # The environment variable should be set to "false" after 
importing - assert os.environ.get("MEM0_TELEMETRY") == "false" - def test_mem0_telemetry_respects_user_setting(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Test that user-set MEM0_TELEMETRY value is not overwritten.""" - # Remove cached modules to force re-import - modules_to_remove = [key for key in sys.modules if key.startswith("agent_framework_mem0")] - for mod in modules_to_remove: - del sys.modules[mod] +class TestContextManager: + """Test __aenter__/__aexit__ delegation.""" - # Set user preference before import - monkeypatch.setenv("MEM0_TELEMETRY", "true") + async def test_aenter_delegates_to_client(self, mock_mem0_client: AsyncMock) -> None: + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + result = await provider.__aenter__() + assert result is provider + mock_mem0_client.__aenter__.assert_awaited_once() - # Re-import the module - import agent_framework_mem0 + async def test_aexit_closes_auto_created_client(self, mock_mem0_client: AsyncMock) -> None: + """Auto-created clients (_should_close_client=True) are closed on exit.""" + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + provider._should_close_client = True + await provider.__aexit__(None, None, None) + mock_mem0_client.__aexit__.assert_awaited_once() - importlib.reload(agent_framework_mem0) + async def test_aexit_does_not_close_provided_client(self, mock_mem0_client: AsyncMock) -> None: + """Provided clients (_should_close_client=False) are NOT closed on exit.""" + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + assert provider._should_close_client is False + await provider.__aexit__(None, None, None) + mock_mem0_client.__aexit__.assert_not_awaited() - # User setting should be preserved - assert os.environ.get("MEM0_TELEMETRY") == "true" + async def test_async_with_syntax(self, mock_mem0_client: AsyncMock) -> None: + provider = 
Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") + async with provider as p: + assert p is provider diff --git a/python/packages/mem0/tests/test_mem0_new_context_provider.py b/python/packages/mem0/tests/test_mem0_new_context_provider.py deleted file mode 100644 index 96a70c2beb..0000000000 --- a/python/packages/mem0/tests/test_mem0_new_context_provider.py +++ /dev/null @@ -1,352 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. -# pyright: reportPrivateUsage=false - -from __future__ import annotations - -from unittest.mock import AsyncMock, patch - -import pytest -from agent_framework import AgentResponse, Message -from agent_framework._sessions import AgentSession, SessionContext -from agent_framework.exceptions import ServiceInitializationError - -from agent_framework_mem0._context_provider import Mem0ContextProvider - - -@pytest.fixture -def mock_mem0_client() -> AsyncMock: - """Create a mock Mem0 AsyncMemoryClient.""" - from mem0 import AsyncMemoryClient - - mock_client = AsyncMock(spec=AsyncMemoryClient) - mock_client.add = AsyncMock() - mock_client.search = AsyncMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock() - return mock_client - - -# -- Initialization tests ------------------------------------------------------ - - -class TestInit: - """Test Mem0ContextProvider initialization.""" - - def test_init_with_all_params(self, mock_mem0_client: AsyncMock) -> None: - provider = Mem0ContextProvider( - source_id="mem0", - mem0_client=mock_mem0_client, - api_key="key-123", - application_id="app1", - agent_id="agent1", - user_id="user1", - context_prompt="Custom prompt", - ) - assert provider.source_id == "mem0" - assert provider.api_key == "key-123" - assert provider.application_id == "app1" - assert provider.agent_id == "agent1" - assert provider.user_id == "user1" - assert provider.context_prompt == "Custom prompt" - assert provider.mem0_client is mock_mem0_client - 
assert provider._should_close_client is False - - def test_init_default_context_prompt(self, mock_mem0_client: AsyncMock) -> None: - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - assert provider.context_prompt == Mem0ContextProvider.DEFAULT_CONTEXT_PROMPT - - def test_init_auto_creates_client_when_none(self) -> None: - """When no client is provided, a default AsyncMemoryClient is created and flagged for closing.""" - with ( - patch("mem0.client.main.AsyncMemoryClient.__init__", return_value=None) as mock_init, - patch("mem0.client.main.AsyncMemoryClient._validate_api_key", return_value=None), - ): - provider = Mem0ContextProvider(source_id="mem0", api_key="test-key", user_id="u1") - mock_init.assert_called_once_with(api_key="test-key") - assert provider._should_close_client is True - - def test_provided_client_not_flagged_for_close(self, mock_mem0_client: AsyncMock) -> None: - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - assert provider._should_close_client is False - - -# -- before_run tests ---------------------------------------------------------- - - -class TestBeforeRun: - """Test before_run hook.""" - - async def test_memories_added_to_context(self, mock_mem0_client: AsyncMock) -> None: - """Mocked mem0 search returns memories → messages added to context with prompt.""" - mock_mem0_client.search.return_value = [ - {"memory": "User likes Python"}, - {"memory": "User prefers dark mode"}, - ] - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - session = AgentSession(session_id="test-session") - ctx = SessionContext(input_messages=[Message(role="user", text="Hello")], session_id="s1") - - await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - - mock_mem0_client.search.assert_awaited_once() - assert "mem0" in ctx.context_messages - added = 
ctx.context_messages["mem0"] - assert len(added) == 1 - assert "User likes Python" in added[0].text # type: ignore[operator] - assert "User prefers dark mode" in added[0].text # type: ignore[operator] - assert provider.context_prompt in added[0].text # type: ignore[operator] - - async def test_empty_input_skips_search(self, mock_mem0_client: AsyncMock) -> None: - """Empty input messages → no search performed.""" - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - session = AgentSession(session_id="test-session") - ctx = SessionContext(input_messages=[Message(role="user", text="")], session_id="s1") - - await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - - mock_mem0_client.search.assert_not_awaited() - assert "mem0" not in ctx.context_messages - - async def test_empty_search_results_no_messages(self, mock_mem0_client: AsyncMock) -> None: - """Empty search results → no messages added.""" - mock_mem0_client.search.return_value = [] - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - session = AgentSession(session_id="test-session") - ctx = SessionContext(input_messages=[Message(role="user", text="test")], session_id="s1") - - await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - - assert "mem0" not in ctx.context_messages - - async def test_validates_filters_before_search(self, mock_mem0_client: AsyncMock) -> None: - """Raises ServiceInitializationError when no filters.""" - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) - session = AgentSession(session_id="test-session") - ctx = SessionContext(input_messages=[Message(role="user", text="test")], session_id="s1") - - with pytest.raises(ServiceInitializationError, match="At least one of the filters"): - await provider.before_run(agent=None, session=session, context=ctx, 
state=session.state) # type: ignore[arg-type] - - async def test_v1_1_response_format(self, mock_mem0_client: AsyncMock) -> None: - """Search response in v1.1 dict format with 'results' key.""" - mock_mem0_client.search.return_value = {"results": [{"memory": "remembered fact"}]} - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - session = AgentSession(session_id="test-session") - ctx = SessionContext(input_messages=[Message(role="user", text="test")], session_id="s1") - - await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - - added = ctx.context_messages["mem0"] - assert "remembered fact" in added[0].text # type: ignore[operator] - - async def test_search_query_combines_input_messages(self, mock_mem0_client: AsyncMock) -> None: - """Multiple input messages are joined for the search query.""" - mock_mem0_client.search.return_value = [] - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - session = AgentSession(session_id="test-session") - ctx = SessionContext( - input_messages=[ - Message(role="user", text="Hello"), - Message(role="user", text="World"), - ], - session_id="s1", - ) - - await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - - call_kwargs = mock_mem0_client.search.call_args.kwargs - assert call_kwargs["query"] == "Hello\nWorld" - - -# -- after_run tests ----------------------------------------------------------- - - -class TestAfterRun: - """Test after_run hook.""" - - async def test_stores_input_and_response(self, mock_mem0_client: AsyncMock) -> None: - """Stores input+response messages to mem0 via client.add.""" - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - session = AgentSession(session_id="test-session") - ctx = SessionContext(input_messages=[Message(role="user", text="question")], 
session_id="s1") - ctx._response = AgentResponse(messages=[Message(role="assistant", text="answer")]) - - await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - - mock_mem0_client.add.assert_awaited_once() - call_kwargs = mock_mem0_client.add.call_args.kwargs - assert call_kwargs["messages"] == [ - {"role": "user", "content": "question"}, - {"role": "assistant", "content": "answer"}, - ] - assert call_kwargs["user_id"] == "u1" - assert call_kwargs["run_id"] == "s1" - - async def test_only_stores_user_assistant_system(self, mock_mem0_client: AsyncMock) -> None: - """Only stores user/assistant/system messages with text.""" - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - session = AgentSession(session_id="test-session") - ctx = SessionContext( - input_messages=[ - Message(role="user", text="hello"), - Message(role="tool", text="tool output"), - ], - session_id="s1", - ) - ctx._response = AgentResponse(messages=[Message(role="assistant", text="reply")]) - - await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - - call_kwargs = mock_mem0_client.add.call_args.kwargs - roles = [m["role"] for m in call_kwargs["messages"]] - assert "tool" not in roles - assert roles == ["user", "assistant"] - - async def test_skips_empty_messages(self, mock_mem0_client: AsyncMock) -> None: - """Skips messages with empty text.""" - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - session = AgentSession(session_id="test-session") - ctx = SessionContext( - input_messages=[ - Message(role="user", text=""), - Message(role="user", text=" "), - ], - session_id="s1", - ) - ctx._response = AgentResponse(messages=[]) - - await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - - mock_mem0_client.add.assert_not_awaited() - - async 
def test_uses_session_id_as_run_id(self, mock_mem0_client: AsyncMock) -> None: - """Uses session_id as run_id.""" - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - session = AgentSession(session_id="test-session") - ctx = SessionContext(input_messages=[Message(role="user", text="hi")], session_id="my-session") - ctx._response = AgentResponse(messages=[Message(role="assistant", text="hey")]) - - await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - - assert mock_mem0_client.add.call_args.kwargs["run_id"] == "my-session" - - async def test_validates_filters(self, mock_mem0_client: AsyncMock) -> None: - """Raises ServiceInitializationError when no filters.""" - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) - session = AgentSession(session_id="test-session") - ctx = SessionContext(input_messages=[Message(role="user", text="hi")], session_id="s1") - ctx._response = AgentResponse(messages=[Message(role="assistant", text="hey")]) - - with pytest.raises(ServiceInitializationError, match="At least one of the filters"): - await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - - async def test_stores_with_application_id_metadata(self, mock_mem0_client: AsyncMock) -> None: - """application_id is passed in metadata.""" - provider = Mem0ContextProvider( - source_id="mem0", mem0_client=mock_mem0_client, user_id="u1", application_id="app1" - ) - session = AgentSession(session_id="test-session") - ctx = SessionContext(input_messages=[Message(role="user", text="hi")], session_id="s1") - ctx._response = AgentResponse(messages=[]) - - await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] - - assert mock_mem0_client.add.call_args.kwargs["metadata"] == {"application_id": "app1"} - - -# -- _validate_filters tests 
-------------------------------------------------- - - -class TestValidateFilters: - """Test _validate_filters method.""" - - def test_raises_when_no_filters(self, mock_mem0_client: AsyncMock) -> None: - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) - with pytest.raises(ServiceInitializationError, match="At least one of the filters"): - provider._validate_filters() - - def test_passes_with_user_id(self, mock_mem0_client: AsyncMock) -> None: - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - provider._validate_filters() # should not raise - - def test_passes_with_agent_id(self, mock_mem0_client: AsyncMock) -> None: - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, agent_id="a1") - provider._validate_filters() - - def test_passes_with_application_id(self, mock_mem0_client: AsyncMock) -> None: - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, application_id="app1") - provider._validate_filters() - - -# -- _build_filters tests ----------------------------------------------------- - - -class TestBuildFilters: - """Test _build_filters method.""" - - def test_user_id_only(self, mock_mem0_client: AsyncMock) -> None: - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - assert provider._build_filters() == {"user_id": "u1"} - - def test_all_params(self, mock_mem0_client: AsyncMock) -> None: - provider = Mem0ContextProvider( - source_id="mem0", - mem0_client=mock_mem0_client, - user_id="u1", - agent_id="a1", - application_id="app1", - ) - assert provider._build_filters(session_id="sess1") == { - "user_id": "u1", - "agent_id": "a1", - "run_id": "sess1", - "app_id": "app1", - } - - def test_excludes_none_values(self, mock_mem0_client: AsyncMock) -> None: - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - filters = provider._build_filters() - assert 
"agent_id" not in filters - assert "run_id" not in filters - assert "app_id" not in filters - - def test_session_id_mapped_to_run_id(self, mock_mem0_client: AsyncMock) -> None: - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - filters = provider._build_filters(session_id="s99") - assert filters["run_id"] == "s99" - - def test_empty_when_no_params(self, mock_mem0_client: AsyncMock) -> None: - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) - assert provider._build_filters() == {} - - -# -- Context manager tests ----------------------------------------------------- - - -class TestContextManager: - """Test __aenter__/__aexit__ delegation.""" - - async def test_aenter_delegates_to_client(self, mock_mem0_client: AsyncMock) -> None: - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - result = await provider.__aenter__() - assert result is provider - mock_mem0_client.__aenter__.assert_awaited_once() - - async def test_aexit_closes_auto_created_client(self, mock_mem0_client: AsyncMock) -> None: - """Auto-created clients (_should_close_client=True) are closed on exit.""" - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - provider._should_close_client = True - await provider.__aexit__(None, None, None) - mock_mem0_client.__aexit__.assert_awaited_once() - - async def test_aexit_does_not_close_provided_client(self, mock_mem0_client: AsyncMock) -> None: - """Provided clients (_should_close_client=False) are NOT closed on exit.""" - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - assert provider._should_close_client is False - await provider.__aexit__(None, None, None) - mock_mem0_client.__aexit__.assert_not_awaited() - - async def test_async_with_syntax(self, mock_mem0_client: AsyncMock) -> None: - provider = Mem0ContextProvider(source_id="mem0", 
mem0_client=mock_mem0_client, user_id="u1") - async with provider as p: - assert p is provider diff --git a/python/packages/redis/tests/test_new_providers.py b/python/packages/redis/tests/test_providers.py similarity index 100% rename from python/packages/redis/tests/test_new_providers.py rename to python/packages/redis/tests/test_providers.py From bf3725c7aeab6be070ed723afbc0597f59642e5f Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Feb 2026 14:53:42 +0100 Subject: [PATCH 09/28] refactor: rewrite SlidingWindowChatMessageStore as SlidingWindowHistoryProvider(InMemoryHistoryProvider) --- .../__init__.py | 2 +- .../agent_framework_azure_ai/_chat_client.py | 2 +- .../claude/tests/test_claude_agent.py | 7 - .../packages/core/agent_framework/_agents.py | 2 +- .../core/agent_framework/_workflows/_agent.py | 73 +----- .../openai/_assistant_provider.py | 2 +- .../packages/core/tests/core/test_agents.py | 13 +- .../agent_framework_devui/_conversations.py | 8 +- .../agent_framework_durabletask/__init__.py | 2 +- .../agent_framework_durabletask/_shim.py | 3 + .../_sliding_window.py | 71 +++--- .../tau2/agent_framework_lab_tau2/runner.py | 21 +- .../lab/tau2/tests/test_sliding_window.py | 234 +++++++----------- .../orchestrations/tests/test_handoff.py | 2 +- 14 files changed, 157 insertions(+), 285 deletions(-) diff --git a/python/packages/azure-ai-search/agent_framework_azure_ai_search/__init__.py b/python/packages/azure-ai-search/agent_framework_azure_ai_search/__init__.py index e8782e2117..9610be5774 100644 --- a/python/packages/azure-ai-search/agent_framework_azure_ai_search/__init__.py +++ b/python/packages/azure-ai-search/agent_framework_azure_ai_search/__init__.py @@ -10,7 +10,7 @@ __version__ = "0.0.0" # Fallback for development mode __all__ = [ - "AzureAISearchSettings", "AzureAISearchContextProvider", + "AzureAISearchSettings", "__version__", ] diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py 
b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py index 3b3d8b53d2..186cee6d1f 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py @@ -14,8 +14,8 @@ AGENT_FRAMEWORK_USER_AGENT, Agent, Annotation, - BaseContextProvider, BaseChatClient, + BaseContextProvider, ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer, ChatOptions, diff --git a/python/packages/claude/tests/test_claude_agent.py b/python/packages/claude/tests/test_claude_agent.py index 13e625b793..88b63389e8 100644 --- a/python/packages/claude/tests/test_claude_agent.py +++ b/python/packages/claude/tests/test_claude_agent.py @@ -453,13 +453,6 @@ def test_create_session_with_service_session_id(self) -> None: session = agent.create_session(session_id="existing-session-123") assert isinstance(session, AgentSession) - def test_session_inherits_context_provider(self) -> None: - """Test that session inherits context provider.""" - mock_provider = MagicMock() - agent = ClaudeAgent(context_providers=[mock_provider]) - session = agent.create_session() - assert mock_provider in agent.context_providers - async def test_ensure_session_creates_client(self) -> None: """Test _ensure_session creates client when not started.""" with patch("agent_framework_claude._agent.ClaudeSDKClient") as mock_client_class: diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 2ccb657b6d..4d4a8ba221 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -1098,7 +1098,7 @@ async def _run_after_providers( state = session.state if session else {} for provider in reversed(self.context_providers): await provider.after_run( - agent=self, + agent=self, # type: ignore[arg-type] session=session, # type: ignore[arg-type] context=context, state=state, diff --git 
a/python/packages/core/agent_framework/_workflows/_agent.py b/python/packages/core/agent_framework/_workflows/_agent.py index 961ac42b72..fb2f8f2904 100644 --- a/python/packages/core/agent_framework/_workflows/_agent.py +++ b/python/packages/core/agent_framework/_workflows/_agent.py @@ -18,6 +18,7 @@ BaseAgent, Content, Message, + ResponseStream, UsageDetails, ) @@ -157,7 +158,7 @@ def run( checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, **kwargs: Any, - ) -> AsyncIterable[AgentResponseUpdate] | Awaitable[AgentResponse]: + ) -> ResponseStream[AgentResponseUpdate, AgentResponse] | Awaitable[AgentResponse]: """Get a response from the workflow agent. Args: @@ -184,65 +185,19 @@ def run( or AgentResponseUpdate objects. Request info events (type='request_info') will be converted to function call and approval request contents. """ - if stream: - return self._run_streaming( - messages=messages, - session=session, - checkpoint_id=checkpoint_id, - checkpoint_storage=checkpoint_storage, - **kwargs, - ) - return self._run_non_streaming( - messages=messages, - session=session, - checkpoint_id=checkpoint_id, - checkpoint_storage=checkpoint_storage, - **kwargs, - ) - - async def _run_non_streaming( - self, - messages: str | Message | list[str] | list[Message] | None = None, - *, - session: AgentSession | None = None, - checkpoint_id: str | None = None, - checkpoint_storage: CheckpointStorage | None = None, - **kwargs: Any, - ) -> AgentResponse: - """Internal non-streaming implementation.""" input_messages = normalize_messages_input(messages) response_id = str(uuid.uuid4()) - - response = await self._run_impl( - input_messages, response_id, session, checkpoint_id, checkpoint_storage, **kwargs - ) - - return response - - async def _run_streaming( - self, - messages: str | Message | list[str] | list[Message] | None = None, - *, - session: AgentSession | None = None, - checkpoint_id: str | None = None, - checkpoint_storage: CheckpointStorage | 
None = None, - **kwargs: Any, - ) -> AsyncIterable[AgentResponseUpdate]: - """Internal streaming implementation. - - Yields AgentResponseUpdate objects. Output events (type='output') from the workflow - are converted to updates. Request info events (type='request_info') are converted - to function call and approval request contents. - """ + if stream: + return ResponseStream( + self._run_stream_impl( + input_messages, response_id, session, checkpoint_id, checkpoint_storage, **kwargs + ), + finalizer=AgentResponse.from_updates, + ) input_messages = normalize_messages_input(messages) - response_updates: list[AgentResponseUpdate] = [] response_id = str(uuid.uuid4()) - async for update in self._run_stream_impl( - input_messages, response_id, session, checkpoint_id, checkpoint_storage, **kwargs - ): - response_updates.append(update) - yield update + return self._run_impl(input_messages, response_id, session, checkpoint_id, checkpoint_storage, **kwargs) async def _run_impl( self, @@ -268,9 +223,7 @@ async def _run_impl( An AgentResponse representing the workflow execution results. """ output_events: list[WorkflowEvent[Any]] = [] - async for event in self._run_core( - input_messages, checkpoint_id, checkpoint_storage, streaming=False, **kwargs - ): + async for event in self._run_core(input_messages, checkpoint_id, checkpoint_storage, streaming=False, **kwargs): if event.type == "output" or event.type == "request_info": output_events.append(event) @@ -299,9 +252,7 @@ async def _run_stream_impl( Yields: AgentResponseUpdate objects representing the workflow execution progress. 
""" - async for event in self._run_core( - input_messages, checkpoint_id, checkpoint_storage, streaming=True, **kwargs - ): + async for event in self._run_core(input_messages, checkpoint_id, checkpoint_storage, streaming=True, **kwargs): updates = self._convert_workflow_event_to_agent_response_updates(response_id, event) for update in updates: yield update diff --git a/python/packages/core/agent_framework/openai/_assistant_provider.py b/python/packages/core/agent_framework/openai/_assistant_provider.py index a64ae87b95..8082a4ad9b 100644 --- a/python/packages/core/agent_framework/openai/_assistant_provider.py +++ b/python/packages/core/agent_framework/openai/_assistant_provider.py @@ -13,8 +13,8 @@ from agent_framework._settings import SecretString, load_settings from .._agents import Agent -from .._sessions import BaseContextProvider from .._middleware import MiddlewareTypes +from .._sessions import BaseContextProvider from .._tools import FunctionTool from .._types import normalize_tools from ..exceptions import ServiceInitializationError diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index cd7b2c7aba..bfa7525d7e 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. 
import contextlib -from collections.abc import AsyncIterable, MutableSequence, Sequence +from collections.abc import AsyncIterable, MutableSequence from typing import Any from unittest.mock import AsyncMock, MagicMock from uuid import uuid4 @@ -14,10 +14,10 @@ AgentResponse, AgentResponseUpdate, AgentSession, + BaseContextProvider, ChatOptions, ChatResponse, Content, - BaseContextProvider, FunctionTool, Message, SupportsAgentRun, @@ -173,7 +173,7 @@ async def test_chat_client_agent_update_session_messages(client: SupportsChatGet async def test_chat_client_agent_update_session_conversation_id_missing(client: SupportsChatGetResponse) -> None: agent = Agent(client=client) - session = AgentSession(service_session_id="123") + session = agent.get_session(service_session_id="123") # With the session-based API, service_session_id is managed directly on the session assert session.service_session_id == "123" @@ -835,11 +835,8 @@ async def test_agent_get_session_with_service_session_id( assert session.service_session_id == "test-thread-123" -@pytest.mark.asyncio -async def test_agent_session_from_dict(chat_client_base: SupportsChatGetResponse, tool_tool: FunctionTool): +def test_agent_session_from_dict(chat_client_base: SupportsChatGetResponse, tool_tool: FunctionTool): """Test AgentSession.from_dict restores a session from serialized state.""" - agent = Agent(client=chat_client_base, tools=[tool_tool]) - # Create serialized session state serialized_state = { "type": "session", @@ -861,7 +858,6 @@ async def test_agent_session_from_dict(chat_client_base: SupportsChatGetResponse # region Test Agent initialization edge cases - def test_chat_agent_calls_update_agent_name_on_client(): """Test that Agent calls _update_agent_name_and_description on client if available.""" mock_client = MagicMock() @@ -937,5 +933,4 @@ async def before_run(self, *, agent, session, context, state): assert options.get("instructions") == "Context-provided instructions" - # endregion diff --git 
a/python/packages/devui/agent_framework_devui/_conversations.py b/python/packages/devui/agent_framework_devui/_conversations.py index 79ec7c7034..2778c2da78 100644 --- a/python/packages/devui/agent_framework_devui/_conversations.py +++ b/python/packages/devui/agent_framework_devui/_conversations.py @@ -401,9 +401,7 @@ async def list_items( if media_type and media_type.startswith("image/"): # Convert to ResponseInputImage - message_contents.append( - ResponseInputImage(type="input_image", image_url=uri, detail="auto") - ) + message_contents.append(ResponseInputImage(type="input_image", image_url=uri, detail="auto")) else: # Convert to ResponseInputFile # Extract filename from URI if possible @@ -411,9 +409,7 @@ async def list_items( if media_type == "application/pdf": filename = "document.pdf" - message_contents.append( - ResponseInputFile(type="input_file", file_url=uri, filename=filename) - ) + message_contents.append(ResponseInputFile(type="input_file", file_url=uri, filename=filename)) elif content_type == "function_call": # Function call - create separate ConversationItem diff --git a/python/packages/durabletask/agent_framework_durabletask/__init__.py b/python/packages/durabletask/agent_framework_durabletask/__init__.py index bb0da56af4..a518b5ad23 100644 --- a/python/packages/durabletask/agent_framework_durabletask/__init__.py +++ b/python/packages/durabletask/agent_framework_durabletask/__init__.py @@ -79,6 +79,7 @@ "DurableAIAgentOrchestrationContext", "DurableAIAgentWorker", "DurableAgentExecutor", + "DurableAgentSession", "DurableAgentState", "DurableAgentStateContent", "DurableAgentStateData", @@ -99,7 +100,6 @@ "DurableAgentStateUriContent", "DurableAgentStateUsage", "DurableAgentStateUsageContent", - "DurableAgentSession", "DurableStateFields", "RunRequest", "__version__", diff --git a/python/packages/durabletask/agent_framework_durabletask/_shim.py b/python/packages/durabletask/agent_framework_durabletask/_shim.py index 8ad40e34b6..10352d8bb7 100644 
--- a/python/packages/durabletask/agent_framework_durabletask/_shim.py +++ b/python/packages/durabletask/agent_framework_durabletask/_shim.py @@ -136,6 +136,9 @@ def create_session(self, **kwargs: Any) -> DurableAgentSession: """Create a new agent session via the provider.""" return self._executor.get_new_session(self.name, **kwargs) + def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession: + return self._executor.get_new_session(self.name, **kwargs) + def _normalize_messages(self, messages: str | Message | list[str] | list[Message] | None) -> str: """Convert supported message inputs to a single string. diff --git a/python/packages/lab/tau2/agent_framework_lab_tau2/_sliding_window.py b/python/packages/lab/tau2/agent_framework_lab_tau2/_sliding_window.py index ad4328ff21..1777f7683b 100644 --- a/python/packages/lab/tau2/agent_framework_lab_tau2/_sliding_window.py +++ b/python/packages/lab/tau2/agent_framework_lab_tau2/_sliding_window.py @@ -1,63 +1,61 @@ # Copyright (c) Microsoft. All rights reserved. import json -from collections.abc import Sequence from typing import Any import tiktoken -from agent_framework import ChatMessageStore, Message +from agent_framework import InMemoryHistoryProvider, Message from loguru import logger -class SlidingWindowChatMessageStore(ChatMessageStore): - """A token-aware sliding window implementation of ChatMessageStore. +class SlidingWindowHistoryProvider(InMemoryHistoryProvider): + """A token-aware sliding window implementation of InMemoryHistoryProvider. - Maintains two message lists: complete history and truncated window. - Automatically removes oldest messages when token limit is exceeded. - Also removes leading tool messages to ensure valid conversation flow. + Stores all messages in session state but returns a truncated window from + ``get_messages`` that fits within ``max_tokens``. Automatically removes + oldest messages and leading tool messages to ensure valid conversation flow. 
""" def __init__( self, - messages: Sequence[Message] | None = None, + source_id: str = "memory", + *, max_tokens: int = 3800, system_message: str | None = None, tool_definitions: Any | None = None, ): - super().__init__(messages=messages) - self.truncated_messages = self.messages.copy() + super().__init__(source_id) self.max_tokens = max_tokens self.system_message = system_message # Included in token count self.tool_definitions = tool_definitions # An estimation based on a commonly used vocab table self.encoding = tiktoken.get_encoding("o200k_base") - async def add_messages(self, messages: Sequence[Message]) -> None: - await super().add_messages(messages) + async def get_messages( + self, session_id: str | None, *, state: dict[str, Any] | None = None, **kwargs: Any + ) -> list[Message]: + """Retrieve messages from session state, truncated to fit within max_tokens.""" + all_messages = await super().get_messages(session_id, state=state, **kwargs) + return self._truncate(list(all_messages)) - self.truncated_messages = self.messages.copy() - self.truncate_messages() - - async def list_messages(self) -> list[Message]: - """Get the current list of messages, which may be truncated.""" - return self.truncated_messages - - async def list_all_messages(self) -> list[Message]: + async def get_all_messages(self, *, state: dict[str, Any] | None = None) -> list[Message]: """Get all messages from the store including the truncated ones.""" - return self.messages + return await super().get_messages(None, state=state) - def truncate_messages(self) -> None: - while len(self.truncated_messages) > 0 and self.get_token_count() > self.max_tokens: + def _truncate(self, messages: list[Message]) -> list[Message]: + """Truncate messages to fit within max_tokens and remove leading tool messages.""" + while len(messages) > 0 and self._get_token_count(messages) > self.max_tokens: logger.warning("Messages exceed max tokens. 
Truncating oldest message.") - self.truncated_messages.pop(0) + messages.pop(0) # Remove leading tool messages - while len(self.truncated_messages) > 0: - if self.truncated_messages[0].role != "tool": + while len(messages) > 0: + if messages[0].role != "tool": break logger.warning("Removing leading tool message because tool result cannot be the first message.") - self.truncated_messages.pop(0) + messages.pop(0) + return messages - def get_token_count(self) -> int: + def _get_token_count(self, messages: list[Message]) -> int: """Estimate token count for a list of messages using tiktoken. Returns: @@ -70,7 +68,7 @@ def get_token_count(self) -> int: total_tokens += len(self.encoding.encode(self.system_message)) total_tokens += 4 # Extra tokens for system message formatting - for msg in self.truncated_messages: + for msg in messages: # Add 4 tokens per message for role, formatting, etc. total_tokens += 4 @@ -87,7 +85,7 @@ def get_token_count(self) -> int: "name": content.name, "arguments": content.arguments, } - total_tokens += self.estimate_any_object_token_count(func_call_data) + total_tokens += self._estimate_any_object_token_count(func_call_data) elif content.type == "function_result": total_tokens += 4 # Serialize function result and count tokens @@ -95,19 +93,16 @@ def get_token_count(self) -> int: "call_id": content.call_id, "result": content.result, } - total_tokens += self.estimate_any_object_token_count(func_result_data) + total_tokens += self._estimate_any_object_token_count(func_result_data) else: # For other content types, serialize the whole content - total_tokens += self.estimate_any_object_token_count(content) + total_tokens += self._estimate_any_object_token_count(content) else: # Content without type, treat as text - total_tokens += self.estimate_any_object_token_count(content) + total_tokens += self._estimate_any_object_token_count(content) elif hasattr(msg, "text") and msg.text: # Simple text message - total_tokens += 
self.estimate_any_object_token_count(msg.text) - else: - # Skip it - pass + total_tokens += self._estimate_any_object_token_count(msg.text) if total_tokens > self.max_tokens / 2: logger.opt(colors=True).warning( @@ -122,7 +117,7 @@ def get_token_count(self) -> int: return total_tokens - def estimate_any_object_token_count(self, obj: Any) -> int: + def _estimate_any_object_token_count(self, obj: Any) -> int: try: serialized = json.dumps(obj) except Exception: diff --git a/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py b/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py index 68205c880e..5b5dc71c31 100644 --- a/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py +++ b/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py @@ -32,7 +32,7 @@ from tau2.utils.utils import get_now # type: ignore[import-untyped] from ._message_utils import flip_messages, log_messages -from ._sliding_window import SlidingWindowChatMessageStore +from ._sliding_window import SlidingWindowHistoryProvider from ._tau2_utils import convert_agent_framework_messages_to_tau2_messages, convert_tau2_tool_to_function_tool __all__ = ["ASSISTANT_AGENT_ID", "ORCHESTRATOR_ID", "USER_SIMULATOR_ID", "TaskRunner"] @@ -201,11 +201,13 @@ def assistant_agent(self, assistant_chat_client: SupportsChatGetResponse) -> Age instructions=assistant_system_prompt, tools=tools, temperature=self.assistant_sampling_temperature, - chat_message_store_factory=lambda: SlidingWindowChatMessageStore( - system_message=assistant_system_prompt, - tool_definitions=[tool.openai_schema for tool in tools], - max_tokens=self.assistant_window_size, - ), + context_providers=[ + SlidingWindowHistoryProvider( + system_message=assistant_system_prompt, + tool_definitions=[tool.openai_schema for tool in tools], + max_tokens=self.assistant_window_size, + ) + ], ) def user_simulator(self, user_simuator_chat_client: SupportsChatGetResponse, task: Task) -> Agent: @@ -354,11 +356,12 @@ async def run( # STEP 5: 
Ensemble the conversation history needed for evaluation. # It's coming from three parts: # 1. The initial greeting - # 2. The assistant's message store (not just the truncated window) + # 2. The assistant's session state (full history, not just the truncated window) # 3. The final user message (if any) assistant_executor = cast(AgentExecutor, self._assistant_executor) - message_store = cast(SlidingWindowChatMessageStore, assistant_executor._agent_thread.message_store) - full_conversation = [first_message] + await message_store.list_all_messages() + history_provider = cast(SlidingWindowHistoryProvider, assistant_executor._agent.context_providers[0]) + all_messages = await history_provider.get_all_messages(state=assistant_executor._session.state) + full_conversation = [first_message] + all_messages if self._final_user_message is not None: full_conversation.extend(self._final_user_message) diff --git a/python/packages/lab/tau2/tests/test_sliding_window.py b/python/packages/lab/tau2/tests/test_sliding_window.py index c991f5b568..9439c4afac 100644 --- a/python/packages/lab/tau2/tests/test_sliding_window.py +++ b/python/packages/lab/tau2/tests/test_sliding_window.py @@ -1,145 +1,121 @@ # Copyright (c) Microsoft. All rights reserved. 
-"""Tests for sliding window message list.""" +"""Tests for sliding window history provider.""" from unittest.mock import patch +from agent_framework import AgentSession from agent_framework._types import Content, Message -from agent_framework_lab_tau2._sliding_window import SlidingWindowChatMessageStore +from agent_framework_lab_tau2._sliding_window import SlidingWindowHistoryProvider -def test_initialization_empty(): - """Test initializing with no messages.""" - sliding_window = SlidingWindowChatMessageStore(max_tokens=1000) +def _make_state(provider: SlidingWindowHistoryProvider, messages: list[Message] | None = None) -> dict: + """Helper to create a session state dict with messages pre-loaded.""" + state: dict = {} + if messages: + state[provider.source_id] = {"messages": list(messages)} + return state - assert sliding_window.max_tokens == 1000 - assert sliding_window.system_message is None - assert sliding_window.tool_definitions is None - assert len(sliding_window.messages) == 0 - assert len(sliding_window.truncated_messages) == 0 - -def test_initialization_with_parameters(): - """Test initializing with system message and tool definitions.""" - system_msg = "You are a helpful assistant" - tool_defs = [{"name": "test_tool", "description": "A test tool"}] - - sliding_window = SlidingWindowChatMessageStore( - max_tokens=2000, system_message=system_msg, tool_definitions=tool_defs +def test_initialization(): + """Test initializing with parameters.""" + provider = SlidingWindowHistoryProvider( + max_tokens=2000, + system_message="You are a helpful assistant", + tool_definitions=[{"name": "test_tool"}], ) - assert sliding_window.max_tokens == 2000 - assert sliding_window.system_message == system_msg - assert sliding_window.tool_definitions == tool_defs - - -def test_initialization_with_messages(): - """Test initializing with existing messages.""" - messages = [ - Message(role="user", contents=[Content.from_text(text="Hello")]), - Message(role="assistant", 
contents=[Content.from_text(text="Hi there!")]), - ] + assert provider.max_tokens == 2000 + assert provider.system_message == "You are a helpful assistant" + assert provider.tool_definitions == [{"name": "test_tool"}] + assert provider.source_id == "memory" - sliding_window = SlidingWindowChatMessageStore(messages=messages, max_tokens=1000) - assert len(sliding_window.messages) == 2 - assert len(sliding_window.truncated_messages) == 2 +async def test_get_messages_empty(): + """Test getting messages from empty state.""" + provider = SlidingWindowHistoryProvider(max_tokens=1000) + messages = await provider.get_messages(None, state={}) + assert messages == [] -async def test_add_messages_simple(): - """Test adding messages without truncation.""" - sliding_window = SlidingWindowChatMessageStore(max_tokens=10000) # Large limit - - new_messages = [ +async def test_get_messages_simple(): + """Test getting messages without truncation.""" + provider = SlidingWindowHistoryProvider(max_tokens=10000) + msgs = [ Message(role="user", contents=[Content.from_text(text="What's the weather?")]), Message(role="assistant", contents=[Content.from_text(text="I can help with that.")]), ] + state = _make_state(provider, msgs) - await sliding_window.add_messages(new_messages) - - messages = await sliding_window.list_messages() - assert len(messages) == 2 - assert messages[0].text == "What's the weather?" - assert messages[1].text == "I can help with that." + result = await provider.get_messages(None, state=state) + assert len(result) == 2 + assert result[0].text == "What's the weather?" + assert result[1].text == "I can help with that." 
-async def test_list_all_messages_vs_list_messages(): - """Test difference between list_all_messages and list_messages.""" - sliding_window = SlidingWindowChatMessageStore(max_tokens=50) # Small limit to force truncation +async def test_save_and_get_messages(): + """Test saving then getting messages with truncation.""" + provider = SlidingWindowHistoryProvider(max_tokens=50) + state: dict = {} - # Add many messages to trigger truncation - messages = [ + # Save many messages + msgs = [ Message(role="user", contents=[Content.from_text(text=f"Message {i} with some content")]) for i in range(10) ] + await provider.save_messages(None, msgs, state=state) - await sliding_window.add_messages(messages) - - truncated_messages = await sliding_window.list_messages() - all_messages = await sliding_window.list_all_messages() + # get_messages returns truncated + truncated = await provider.get_messages(None, state=state) + # get_all_messages returns full history + all_msgs = await provider.get_all_messages(state=state) - # All messages should contain everything - assert len(all_messages) == 10 - - # Truncated messages should be fewer due to token limit - assert len(truncated_messages) < len(all_messages) + assert len(all_msgs) == 10 + assert len(truncated) < len(all_msgs) def test_get_token_count_basic(): """Test basic token counting.""" - sliding_window = SlidingWindowChatMessageStore(max_tokens=1000) - sliding_window.truncated_messages = [Message(role="user", contents=[Content.from_text(text="Hello")])] - - token_count = sliding_window.get_token_count() + provider = SlidingWindowHistoryProvider(max_tokens=1000) + messages = [Message(role="user", contents=[Content.from_text(text="Hello")])] - # Should be more than 0 (exact count depends on encoding) + token_count = provider._get_token_count(messages) assert token_count > 0 def test_get_token_count_with_system_message(): """Test token counting includes system message.""" - system_msg = "You are a helpful assistant" - 
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000, system_message=system_msg) + provider = SlidingWindowHistoryProvider(max_tokens=1000, system_message="You are a helpful assistant") - # Without messages - token_count_empty = sliding_window.get_token_count() + count_empty = provider._get_token_count([]) + count_with_msg = provider._get_token_count([Message(role="user", contents=[Content.from_text(text="Hello")])]) - # Add a message - sliding_window.truncated_messages = [Message(role="user", contents=[Content.from_text(text="Hello")])] - token_count_with_message = sliding_window.get_token_count() - - # With message should be more tokens - assert token_count_with_message > token_count_empty - assert token_count_empty > 0 # System message contributes tokens + assert count_with_msg > count_empty + assert count_empty > 0 # System message contributes tokens def test_get_token_count_function_call(): """Test token counting with function calls.""" function_call = Content.from_function_call(call_id="call_123", name="test_function", arguments={"param": "value"}) + provider = SlidingWindowHistoryProvider(max_tokens=1000) - sliding_window = SlidingWindowChatMessageStore(max_tokens=1000) - sliding_window.truncated_messages = [Message(role="assistant", contents=[function_call])] - - token_count = sliding_window.get_token_count() + token_count = provider._get_token_count([Message(role="assistant", contents=[function_call])]) assert token_count > 0 def test_get_token_count_function_result(): """Test token counting with function results.""" function_result = Content.from_function_result(call_id="call_123", result={"success": True, "data": "result"}) + provider = SlidingWindowHistoryProvider(max_tokens=1000) - sliding_window = SlidingWindowChatMessageStore(max_tokens=1000) - sliding_window.truncated_messages = [Message(role="tool", contents=[function_result])] - - token_count = sliding_window.get_token_count() + token_count = 
provider._get_token_count([Message(role="tool", contents=[function_result])]) assert token_count > 0 @patch("agent_framework_lab_tau2._sliding_window.logger") -def test_truncate_messages_removes_old_messages(mock_logger): +def test_truncate_removes_old_messages(mock_logger): """Test that truncation removes old messages when token limit exceeded.""" - sliding_window = SlidingWindowChatMessageStore(max_tokens=20) # Very small limit + provider = SlidingWindowHistoryProvider(max_tokens=20) - # Create messages that will exceed the limit messages = [ Message( role="user", @@ -154,80 +130,45 @@ def test_truncate_messages_removes_old_messages(mock_logger): Message(role="user", contents=[Content.from_text(text="Short msg")]), ] - sliding_window.truncated_messages = messages.copy() - sliding_window.truncate_messages() - - # Should have fewer messages after truncation - assert len(sliding_window.truncated_messages) < len(messages) - - # Should have logged warnings + result = provider._truncate(list(messages)) + assert len(result) < len(messages) assert mock_logger.warning.called @patch("agent_framework_lab_tau2._sliding_window.logger") -def test_truncate_messages_removes_leading_tool_messages(mock_logger): +def test_truncate_removes_leading_tool_messages(mock_logger): """Test that truncation removes leading tool messages.""" - sliding_window = SlidingWindowChatMessageStore(max_tokens=10000) # Large limit + provider = SlidingWindowHistoryProvider(max_tokens=10000) - # Create messages starting with tool message tool_message = Message(role="tool", contents=[Content.from_function_result(call_id="call_123", result="result")]) user_message = Message(role="user", contents=[Content.from_text(text="Hello")]) - sliding_window.truncated_messages = [tool_message, user_message] - sliding_window.truncate_messages() - - # Tool message should be removed from the beginning - assert len(sliding_window.truncated_messages) == 1 - assert sliding_window.truncated_messages[0].role == "user" - - # 
Should have logged warning about removing tool message + result = provider._truncate([tool_message, user_message]) + assert len(result) == 1 + assert result[0].role == "user" mock_logger.warning.assert_called() -def test_estimate_any_object_token_count_dict(): - """Test token counting for dictionary objects.""" - sliding_window = SlidingWindowChatMessageStore(max_tokens=1000) - - test_dict = {"key": "value", "number": 42} - token_count = sliding_window.estimate_any_object_token_count(test_dict) - - assert token_count > 0 - - -def test_estimate_any_object_token_count_string(): - """Test token counting for string objects.""" - sliding_window = SlidingWindowChatMessageStore(max_tokens=1000) +def test_estimate_any_object_token_count(): + """Test token counting for various object types.""" + provider = SlidingWindowHistoryProvider(max_tokens=1000) - test_string = "This is a test string" - token_count = sliding_window.estimate_any_object_token_count(test_string) + assert provider._estimate_any_object_token_count({"key": "value"}) > 0 + assert provider._estimate_any_object_token_count("test string") > 0 - assert token_count > 0 - - -def test_estimate_any_object_token_count_non_serializable(): - """Test token counting for non-JSON-serializable objects.""" - sliding_window = SlidingWindowChatMessageStore(max_tokens=1000) - - # Create an object that can't be JSON serialized - class CustomObject: + # Non-serializable falls back to str() + class Custom: def __str__(self): - return "CustomObject instance" + return "Custom instance" - custom_obj = CustomObject() - token_count = sliding_window.estimate_any_object_token_count(custom_obj) - - # Should fall back to string representation - assert token_count > 0 + assert provider._estimate_any_object_token_count(Custom()) > 0 async def test_real_world_scenario(): """Test a realistic conversation scenario.""" - sliding_window = SlidingWindowChatMessageStore( - max_tokens=30, - system_message="You are a helpful assistant", # Moderate 
limit - ) + provider = SlidingWindowHistoryProvider(max_tokens=30, system_message="You are a helpful assistant") + state: dict = {} - # Simulate a conversation conversation = [ Message(role="user", contents=[Content.from_text(text="Hello, how are you?")]), Message( @@ -253,18 +194,13 @@ async def test_real_world_scenario(): ), ] - await sliding_window.add_messages(conversation) - - current_messages = await sliding_window.list_messages() - all_messages = await sliding_window.list_all_messages() + await provider.save_messages(None, conversation, state=state) - # All messages should be preserved - assert len(all_messages) == 6 + truncated = await provider.get_messages(None, state=state) + all_msgs = await provider.get_all_messages(state=state) - # Current messages might be truncated - assert len(current_messages) <= 6 + assert len(all_msgs) == 6 + assert len(truncated) <= 6 - # Token count should be within or close to limit - token_count = sliding_window.get_token_count() - # Allow some margin since truncation happens when exceeded - assert token_count <= sliding_window.max_tokens * 1.1 + token_count = provider._get_token_count(truncated) + assert token_count <= provider.max_tokens * 1.1 diff --git a/python/packages/orchestrations/tests/test_handoff.py b/python/packages/orchestrations/tests/test_handoff.py index c947b46524..e8778a86ca 100644 --- a/python/packages/orchestrations/tests/test_handoff.py +++ b/python/packages/orchestrations/tests/test_handoff.py @@ -7,10 +7,10 @@ import pytest from agent_framework import ( Agent, + BaseContextProvider, ChatResponse, ChatResponseUpdate, Content, - BaseContextProvider, Message, ResponseStream, WorkflowEvent, From 0eaac4f9998a8f0ece333b4ac0b11b1cb40f4051 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Feb 2026 14:58:47 +0100 Subject: [PATCH 10/28] fix: read full history from session state directly instead of reaching into provider internals --- python/packages/core/agent_framework/_agents.py | 2 +- 
.../lab/tau2/agent_framework_lab_tau2/_sliding_window.py | 4 ---- .../packages/lab/tau2/agent_framework_lab_tau2/runner.py | 6 ++---- python/packages/lab/tau2/tests/test_sliding_window.py | 7 +++---- 4 files changed, 6 insertions(+), 13 deletions(-) diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 4d4a8ba221..a84445678c 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -1150,7 +1150,7 @@ async def _prepare_session_and_messages( if isinstance(provider, BaseHistoryProvider) and not provider.load_messages: continue await provider.before_run( - agent=self, + agent=self, # type: ignore[arg-type] session=session, # type: ignore[arg-type] context=session_context, state=state, diff --git a/python/packages/lab/tau2/agent_framework_lab_tau2/_sliding_window.py b/python/packages/lab/tau2/agent_framework_lab_tau2/_sliding_window.py index 1777f7683b..1be7a7a318 100644 --- a/python/packages/lab/tau2/agent_framework_lab_tau2/_sliding_window.py +++ b/python/packages/lab/tau2/agent_framework_lab_tau2/_sliding_window.py @@ -38,10 +38,6 @@ async def get_messages( all_messages = await super().get_messages(session_id, state=state, **kwargs) return self._truncate(list(all_messages)) - async def get_all_messages(self, *, state: dict[str, Any] | None = None) -> list[Message]: - """Get all messages from the store including the truncated ones.""" - return await super().get_messages(None, state=state) - def _truncate(self, messages: list[Message]) -> list[Message]: """Truncate messages to fit within max_tokens and remove leading tool messages.""" while len(messages) > 0 and self._get_token_count(messages) > self.max_tokens: diff --git a/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py b/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py index 5b5dc71c31..8993df29fd 100644 --- a/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py 
+++ b/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py @@ -3,7 +3,6 @@ from __future__ import annotations import uuid -from typing import cast from agent_framework import ( Agent, @@ -358,9 +357,8 @@ async def run( # 1. The initial greeting # 2. The assistant's session state (full history, not just the truncated window) # 3. The final user message (if any) - assistant_executor = cast(AgentExecutor, self._assistant_executor) - history_provider = cast(SlidingWindowHistoryProvider, assistant_executor._agent.context_providers[0]) - all_messages = await history_provider.get_all_messages(state=assistant_executor._session.state) + session_state = self._assistant_executor._session.state + all_messages: list[Message] = list(session_state.get("memory", {}).get("messages", [])) full_conversation = [first_message] + all_messages if self._final_user_message is not None: full_conversation.extend(self._final_user_message) diff --git a/python/packages/lab/tau2/tests/test_sliding_window.py b/python/packages/lab/tau2/tests/test_sliding_window.py index 9439c4afac..833474eb13 100644 --- a/python/packages/lab/tau2/tests/test_sliding_window.py +++ b/python/packages/lab/tau2/tests/test_sliding_window.py @@ -4,7 +4,6 @@ from unittest.mock import patch -from agent_framework import AgentSession from agent_framework._types import Content, Message from agent_framework_lab_tau2._sliding_window import SlidingWindowHistoryProvider @@ -66,8 +65,8 @@ async def test_save_and_get_messages(): # get_messages returns truncated truncated = await provider.get_messages(None, state=state) - # get_all_messages returns full history - all_msgs = await provider.get_all_messages(state=state) + # Full history is in session state + all_msgs = state[provider.source_id]["messages"] assert len(all_msgs) == 10 assert len(truncated) < len(all_msgs) @@ -197,7 +196,7 @@ async def test_real_world_scenario(): await provider.save_messages(None, conversation, state=state) truncated = await 
provider.get_messages(None, state=state) - all_msgs = await provider.get_all_messages(state=state) + all_msgs = state[provider.source_id]["messages"] assert len(all_msgs) == 6 assert len(truncated) <= 6 From 3510e06e1f241297638961038e02920056eb9c91 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Feb 2026 15:18:01 +0100 Subject: [PATCH 11/28] fix: update stale .pyi stubs, sample imports, and README references for new provider types --- .../packages/core/agent_framework/_tools.py | 4 +- .../core/agent_framework/_workflows/_agent.py | 2 +- .../core/agent_framework/mem0/__init__.pyi | 4 +- .../core/agent_framework/redis/__init__.pyi | 8 ++-- .../tau2/agent_framework_lab_tau2/runner.py | 7 ++-- .../02-agents/context_providers/README.md | 39 ++++++++++--------- .../aggregate_context_provider.py | 3 +- .../azure_ai_search/README.md | 4 +- .../context_providers/redis/README.md | 4 +- .../context_providers/redis/redis_basics.py | 2 +- .../redis/redis_conversation.py | 2 +- .../context_providers/redis/redis_threads.py | 2 +- .../simple_context_provider.py | 2 +- .../custom_chat_message_store_thread.py | 2 +- .../02-agents/providers/azure_ai/README.md | 2 +- ...> function_tool_with_session_injection.py} | 14 ++++--- python/samples/03-workflows/README.md | 2 +- .../azure_ai_agents_with_shared_thread.py | 1 - .../workflow_as_agent_checkpoint.py | 1 - .../agent_with_text_search_rag/main.py | 1 - python/samples/autogen-migration/README.md | 2 +- .../semantic-kernel-migration/README.md | 2 +- 22 files changed, 55 insertions(+), 55 deletions(-) rename python/samples/02-agents/tools/{function_tool_with_thread_injection.py => function_tool_with_session_injection.py} (78%) diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 6362433892..0f0cded4a3 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -468,7 +468,7 @@ async def invoke( "chat_options", 
"tools", "tool_choice", - "thread", + "session", "conversation_id", "options", "response_format", @@ -1897,7 +1897,7 @@ def get_response( config=self.function_invocation_configuration, middleware_pipeline=function_middleware_pipeline, ) - filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"} + filtered_kwargs = {k: v for k, v in kwargs.items() if k != "session"} # Make options mutable so we can update conversation_id during function invocation loop mutable_options: dict[str, Any] = dict(options) if options else {} # Remove additional_function_arguments from options passed to underlying chat client diff --git a/python/packages/core/agent_framework/_workflows/_agent.py b/python/packages/core/agent_framework/_workflows/_agent.py index fb2f8f2904..dfa078a9fa 100644 --- a/python/packages/core/agent_framework/_workflows/_agent.py +++ b/python/packages/core/agent_framework/_workflows/_agent.py @@ -135,7 +135,7 @@ def run( checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, **kwargs: Any, - ) -> AsyncIterable[AgentResponseUpdate]: ... + ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ... @overload async def run( diff --git a/python/packages/core/agent_framework/mem0/__init__.pyi b/python/packages/core/agent_framework/mem0/__init__.pyi index 29250a02ad..18ee3bf2bd 100644 --- a/python/packages/core/agent_framework/mem0/__init__.pyi +++ b/python/packages/core/agent_framework/mem0/__init__.pyi @@ -1,11 +1,11 @@ # Copyright (c) Microsoft. All rights reserved. 
from agent_framework_mem0 import ( - Mem0Provider, + Mem0ContextProvider, __version__, ) __all__ = [ - "Mem0Provider", + "Mem0ContextProvider", "__version__", ] diff --git a/python/packages/core/agent_framework/redis/__init__.pyi b/python/packages/core/agent_framework/redis/__init__.pyi index 6cce35db76..fc62badb76 100644 --- a/python/packages/core/agent_framework/redis/__init__.pyi +++ b/python/packages/core/agent_framework/redis/__init__.pyi @@ -1,13 +1,13 @@ # Copyright (c) Microsoft. All rights reserved. from agent_framework_redis import ( - RedisChatMessageStore, - RedisProvider, + RedisContextProvider, + RedisHistoryProvider, __version__, ) __all__ = [ - "RedisChatMessageStore", - "RedisProvider", + "RedisContextProvider", + "RedisHistoryProvider", "__version__", ] diff --git a/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py b/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py index 8993df29fd..9437bbc08a 100644 --- a/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py +++ b/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py @@ -3,6 +3,7 @@ from __future__ import annotations import uuid +from typing import Any from agent_framework import ( Agent, @@ -357,9 +358,9 @@ async def run( # 1. The initial greeting # 2. The assistant's session state (full history, not just the truncated window) # 3. 
The final user message (if any) - session_state = self._assistant_executor._session.state - all_messages: list[Message] = list(session_state.get("memory", {}).get("messages", [])) - full_conversation = [first_message] + all_messages + session_state: dict[str, Any] = self._assistant_executor._session.state # type: ignore + all_messages: list[Message] = list(session_state.get("memory", {}).get("messages", [])) # type: ignore + full_conversation = [first_message, *all_messages] if self._final_user_message is not None: full_conversation.extend(self._final_user_message) diff --git a/python/samples/02-agents/context_providers/README.md b/python/samples/02-agents/context_providers/README.md index 70b2fdb8ff..442ebce61a 100644 --- a/python/samples/02-agents/context_providers/README.md +++ b/python/samples/02-agents/context_providers/README.md @@ -136,37 +136,38 @@ Different agents with isolated or shared memory configurations. ## Building Custom Context Providers -To create a custom context provider, implement the `ContextProvider` protocol: +To create a custom context provider, extend `BaseContextProvider`: ```python -from agent_framework import ContextProvider, Context, Message -from collections.abc import MutableSequence, Sequence +from agent_framework import AgentSession, BaseContextProvider, SessionContext, Message from typing import Any -class MyContextProvider(ContextProvider): - async def invoking( +class MyContextProvider(BaseContextProvider): + def __init__(self): + super().__init__("my-context") + + async def before_run( self, - messages: Message | MutableSequence[Message], - **kwargs: Any - ) -> Context: + *, + agent: Any, + session: AgentSession | None, + context: SessionContext, + state: dict[str, Any], + ) -> None: """Provide context before the agent processes the request.""" - # Return additional instructions, messages, or context - return Context(instructions="Additional instructions here") + context.extend_messages(self.source_id, [Message("system", 
["Additional instructions here"])]) - async def invoked( + async def after_run( self, - request_messages: Message | Sequence[Message], - response_messages: Message | Sequence[Message] | None = None, - invoke_exception: Exception | None = None, - **kwargs: Any, + *, + agent: Any, + session: AgentSession | None, + context: SessionContext, + state: dict[str, Any], ) -> None: """Process the response after the agent generates it.""" # Store information, update memory, etc. pass - - def serialize(self) -> str: - """Serialize the provider state for persistence.""" - return "{}" ``` See `simple_context_provider.py` for a complete example. diff --git a/python/samples/02-agents/context_providers/aggregate_context_provider.py b/python/samples/02-agents/context_providers/aggregate_context_provider.py index 4e5cfb72aa..3f8e346b34 100644 --- a/python/samples/02-agents/context_providers/aggregate_context_provider.py +++ b/python/samples/02-agents/context_providers/aggregate_context_provider.py @@ -13,11 +13,10 @@ import asyncio from typing import Any -from agent_framework import Agent, AgentSession, BaseContextProvider, Message, SessionContext +from agent_framework import Agent, AgentSession, BaseContextProvider, SessionContext from agent_framework.azure import AzureAIClient from azure.identity.aio import AzureCliCredential - # region Example Context Providers diff --git a/python/samples/02-agents/context_providers/azure_ai_search/README.md b/python/samples/02-agents/context_providers/azure_ai_search/README.md index ecb00f68b4..49403d106c 100644 --- a/python/samples/02-agents/context_providers/azure_ai_search/README.md +++ b/python/samples/02-agents/context_providers/azure_ai_search/README.md @@ -144,7 +144,7 @@ async with AzureAIAgentClient(credential=DefaultAzureCredential()) as client: async with Agent( client=client, model=model_deployment, - context_provider=search_provider, + context_providers=[search_provider], ) as agent: response = await agent.run("What information is in 
the knowledge base?") ``` @@ -169,7 +169,7 @@ search_provider = AzureAISearchContextProvider( async with Agent( client=client, model=model_deployment, - context_provider=search_provider, + context_providers=[search_provider], ) as agent: response = await agent.run("Analyze and compare topics across documents") ``` diff --git a/python/samples/02-agents/context_providers/redis/README.md b/python/samples/02-agents/context_providers/redis/README.md index e0fde57bf2..dec2c77485 100644 --- a/python/samples/02-agents/context_providers/redis/README.md +++ b/python/samples/02-agents/context_providers/redis/README.md @@ -8,9 +8,9 @@ This folder contains an example demonstrating how to use the Redis context provi | File | Description | |------|-------------| -| [`azure_redis_conversation.py`](azure_redis_conversation.py) | Demonstrates conversation persistence with RedisChatMessageStore and Azure Redis with Azure AD (Entra ID) authentication using credential provider. | +| [`azure_redis_conversation.py`](azure_redis_conversation.py) | Demonstrates conversation persistence with RedisHistoryProvider and Azure Redis with Azure AD (Entra ID) authentication using credential provider. | | [`redis_basics.py`](redis_basics.py) | Shows standalone provider usage and agent integration. Demonstrates writing messages to Redis, retrieving context via full‑text or hybrid vector search, and persisting preferences across threads. Also includes a simple tool example whose outputs are remembered. | -| [`redis_conversation.py`](redis_conversation.py) | Simple example showing conversation persistence with RedisChatMessageStore using traditional connection string authentication. | +| [`redis_conversation.py`](redis_conversation.py) | Simple example showing conversation persistence with RedisContextProvider using traditional connection string authentication. | | [`redis_threads.py`](redis_threads.py) | Demonstrates thread scoping. 
Includes: (1) global thread scope with a fixed `thread_id` shared across operations; (2) per‑operation thread scope where `scope_to_per_operation_thread_id=True` binds memory to a single thread for the provider's lifetime; and (3) multiple agents with isolated memory via different `agent_id` values. | diff --git a/python/samples/02-agents/context_providers/redis/redis_basics.py b/python/samples/02-agents/context_providers/redis/redis_basics.py index ba038096db..81238eb171 100644 --- a/python/samples/02-agents/context_providers/redis/redis_basics.py +++ b/python/samples/02-agents/context_providers/redis/redis_basics.py @@ -32,7 +32,7 @@ from agent_framework import Message, tool from agent_framework.openai import OpenAIChatClient -from agent_framework_redis._context_provider import RedisContextProvider +from agent_framework.redis import RedisContextProvider from redisvl.extensions.cache.embeddings import EmbeddingsCache from redisvl.utils.vectorize import OpenAITextVectorizer diff --git a/python/samples/02-agents/context_providers/redis/redis_conversation.py b/python/samples/02-agents/context_providers/redis/redis_conversation.py index 6de659aba3..2d345d9930 100644 --- a/python/samples/02-agents/context_providers/redis/redis_conversation.py +++ b/python/samples/02-agents/context_providers/redis/redis_conversation.py @@ -18,7 +18,7 @@ import os from agent_framework.openai import OpenAIChatClient -from agent_framework_redis._context_provider import RedisContextProvider +from agent_framework.redis import RedisContextProvider from redisvl.extensions.cache.embeddings import EmbeddingsCache from redisvl.utils.vectorize import OpenAITextVectorizer diff --git a/python/samples/02-agents/context_providers/redis/redis_threads.py b/python/samples/02-agents/context_providers/redis/redis_threads.py index c11823dfb6..34179048d9 100644 --- a/python/samples/02-agents/context_providers/redis/redis_threads.py +++ b/python/samples/02-agents/context_providers/redis/redis_threads.py @@ 
-31,7 +31,7 @@ import uuid from agent_framework.openai import OpenAIChatClient -from agent_framework_redis._context_provider import RedisContextProvider +from agent_framework.redis import RedisContextProvider from redisvl.extensions.cache.embeddings import EmbeddingsCache from redisvl.utils.vectorize import OpenAITextVectorizer diff --git a/python/samples/02-agents/context_providers/simple_context_provider.py b/python/samples/02-agents/context_providers/simple_context_provider.py index 940e6a057e..fd2a7ce747 100644 --- a/python/samples/02-agents/context_providers/simple_context_provider.py +++ b/python/samples/02-agents/context_providers/simple_context_provider.py @@ -3,7 +3,7 @@ import asyncio from typing import Any -from agent_framework import Agent, AgentSession, BaseContextProvider, Message, SessionContext, SupportsChatGetResponse +from agent_framework import Agent, AgentSession, BaseContextProvider, SessionContext, SupportsChatGetResponse from agent_framework.azure import AzureAIClient from azure.identity.aio import AzureCliCredential from pydantic import BaseModel diff --git a/python/samples/02-agents/conversations/custom_chat_message_store_thread.py b/python/samples/02-agents/conversations/custom_chat_message_store_thread.py index 9470e7bbb2..e3ce5c5905 100644 --- a/python/samples/02-agents/conversations/custom_chat_message_store_thread.py +++ b/python/samples/02-agents/conversations/custom_chat_message_store_thread.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from typing import Any -from agent_framework import AgentSession, BaseHistoryProvider, Message, SessionContext +from agent_framework import AgentSession, BaseHistoryProvider, Message from agent_framework.openai import OpenAIChatClient """ diff --git a/python/samples/02-agents/providers/azure_ai/README.md b/python/samples/02-agents/providers/azure_ai/README.md index 55724e39fd..70e08b4fad 100644 --- a/python/samples/02-agents/providers/azure_ai/README.md +++ 
b/python/samples/02-agents/providers/azure_ai/README.md @@ -20,7 +20,7 @@ This folder contains examples demonstrating different ways to create and use age | [`azure_ai_with_code_interpreter_file_download.py`](azure_ai_with_code_interpreter_file_download.py) | Shows how to download files generated by code interpreter using the OpenAI containers API. | | [`azure_ai_with_content_filtering.py`](azure_ai_with_content_filtering.py) | Shows how to enable content filtering (RAI policy) on Azure AI agents using `RaiConfig`. Requires creating an RAI policy in Azure AI Foundry portal first. | | [`azure_ai_with_existing_agent.py`](azure_ai_with_existing_agent.py) | Shows how to work with a pre-existing agent by providing the agent name and version to the Azure AI client. Demonstrates agent reuse patterns for production scenarios. | -| [`azure_ai_with_existing_conversation.py`](azure_ai_with_existing_conversation.py) | Demonstrates how to use an existing conversation created on the service side with Azure AI agents. Shows two approaches: specifying conversation ID at the client level and using AgentThread with an existing conversation ID. | +| [`azure_ai_with_existing_conversation.py`](azure_ai_with_existing_conversation.py) | Demonstrates how to use an existing conversation created on the service side with Azure AI agents. Shows two approaches: specifying conversation ID at the client level and using AgentSession with an existing conversation ID. | | [`azure_ai_with_application_endpoint.py`](azure_ai_with_application_endpoint.py) | Demonstrates calling the Azure AI application-scoped endpoint. | | [`azure_ai_with_explicit_settings.py`](azure_ai_with_explicit_settings.py) | Shows how to create an agent with explicitly configured `AzureAIClient` settings, including project endpoint, model deployment, and credentials rather than relying on environment variable defaults. 
| | [`azure_ai_with_file_search.py`](azure_ai_with_file_search.py) | Shows how to use `AzureAIClient.get_file_search_tool()` with Azure AI agents to upload files, create vector stores, and enable agents to search through uploaded documents to answer user questions. | diff --git a/python/samples/02-agents/tools/function_tool_with_thread_injection.py b/python/samples/02-agents/tools/function_tool_with_session_injection.py similarity index 78% rename from python/samples/02-agents/tools/function_tool_with_thread_injection.py rename to python/samples/02-agents/tools/function_tool_with_session_injection.py index afd2bac555..bc89ca80ec 100644 --- a/python/samples/02-agents/tools/function_tool_with_thread_injection.py +++ b/python/samples/02-agents/tools/function_tool_with_session_injection.py @@ -4,7 +4,7 @@ from typing import Annotated, Any from agent_framework import AgentSession, tool -from agent_framework.openai import OpenAIChatClient +from agent_framework.openai import OpenAIResponsesClient from pydantic import Field """ @@ -25,16 +25,18 @@ async def get_weather( """Get the weather for a given location.""" # Get session object from kwargs session = kwargs.get("session") - if session and isinstance(session, AgentSession): - if session.service_session_id: - print(f"Session ID: {session.service_session_id}.") + if session and isinstance(session, AgentSession) and session.service_session_id: + print(f"Session ID: {session.service_session_id}.") return f"The weather in {location} is cloudy." 
async def main() -> None: - agent = OpenAIChatClient().as_agent( - name="WeatherAgent", instructions="You are a helpful weather assistant.", tools=[get_weather] + agent = OpenAIResponsesClient().as_agent( + name="WeatherAgent", + instructions="You are a helpful weather assistant.", + tools=[get_weather], + options={"store": True}, ) # Create a session diff --git a/python/samples/03-workflows/README.md b/python/samples/03-workflows/README.md index 9ee6f517e9..0c9f2c5df8 100644 --- a/python/samples/03-workflows/README.md +++ b/python/samples/03-workflows/README.md @@ -40,7 +40,7 @@ Once comfortable with these, explore the rest of the samples below. | Custom Agent Executors | [agents/custom_agent_executors.py](./agents/custom_agent_executors.py) | Create executors to handle agent run methods | | Workflow as Agent (Reflection Pattern) | [agents/workflow_as_agent_reflection_pattern.py](./agents/workflow_as_agent_reflection_pattern.py) | Wrap a workflow so it can behave like an agent (reflection pattern) | | Workflow as Agent + HITL | [agents/workflow_as_agent_human_in_the_loop.py](./agents/workflow_as_agent_human_in_the_loop.py) | Extend workflow-as-agent with human-in-the-loop capability | -| Workflow as Agent with Thread | [agents/workflow_as_agent_with_thread.py](./agents/workflow_as_agent_with_thread.py) | Use AgentThread to maintain conversation history across workflow-as-agent invocations | +| Workflow as Agent with Session | [agents/workflow_as_agent_with_thread.py](./agents/workflow_as_agent_with_thread.py) | Use AgentSession to maintain conversation history across workflow-as-agent invocations | | Workflow as Agent kwargs | [agents/workflow_as_agent_kwargs.py](./agents/workflow_as_agent_kwargs.py) | Pass custom context (data, user tokens) via kwargs through workflow.as_agent() to @ai_function tools | ### checkpoint diff --git a/python/samples/03-workflows/agents/azure_ai_agents_with_shared_thread.py 
b/python/samples/03-workflows/agents/azure_ai_agents_with_shared_thread.py index 988d3f539f..f33aa2ef10 100644 --- a/python/samples/03-workflows/agents/azure_ai_agents_with_shared_thread.py +++ b/python/samples/03-workflows/agents/azure_ai_agents_with_shared_thread.py @@ -7,7 +7,6 @@ AgentExecutor, AgentExecutorRequest, AgentExecutorResponse, - InMemoryHistoryProvider, WorkflowBuilder, WorkflowContext, WorkflowRunState, diff --git a/python/samples/03-workflows/checkpoint/workflow_as_agent_checkpoint.py b/python/samples/03-workflows/checkpoint/workflow_as_agent_checkpoint.py index 4b0b7e88b8..4fb9fbbe77 100644 --- a/python/samples/03-workflows/checkpoint/workflow_as_agent_checkpoint.py +++ b/python/samples/03-workflows/checkpoint/workflow_as_agent_checkpoint.py @@ -28,7 +28,6 @@ import os from agent_framework import ( - AgentSession, InMemoryCheckpointStorage, ) from agent_framework.azure import AzureOpenAIResponsesClient diff --git a/python/samples/05-end-to-end/hosted_agents/agent_with_text_search_rag/main.py b/python/samples/05-end-to-end/hosted_agents/agent_with_text_search_rag/main.py index 5c28917b51..e53430ec16 100644 --- a/python/samples/05-end-to-end/hosted_agents/agent_with_text_search_rag/main.py +++ b/python/samples/05-end-to-end/hosted_agents/agent_with_text_search_rag/main.py @@ -2,7 +2,6 @@ import json import sys -from collections.abc import MutableSequence from dataclasses import dataclass from typing import Any diff --git a/python/samples/autogen-migration/README.md b/python/samples/autogen-migration/README.md index 36010fa223..39e6afd582 100644 --- a/python/samples/autogen-migration/README.md +++ b/python/samples/autogen-migration/README.md @@ -52,7 +52,7 @@ python samples/autogen-migration/orchestrations/04_magentic_one.py ## Tips for Migration - **Default behavior differences**: AutoGen's `AssistantAgent` is single-turn by default (`max_tool_iterations=1`), while AF's `Agent` is multi-turn and continues tool execution automatically. 
-- **Thread management**: AF agents are stateless by default. Use `agent.get_new_thread()` and pass it to `run()` to maintain conversation state, similar to AutoGen's conversation context. +- **Thread management**: AF agents are stateless by default. Use `agent.create_session()` and pass it to `run()` to maintain conversation state, similar to AutoGen's conversation context. - **Tools**: AutoGen uses `FunctionTool` wrappers; AF uses `@tool` decorators with automatic schema inference. - **Orchestration patterns**: - `RoundRobinGroupChat` → `SequentialBuilder` or `WorkflowBuilder` diff --git a/python/samples/semantic-kernel-migration/README.md b/python/samples/semantic-kernel-migration/README.md index d04239a00d..6e6a135a0f 100644 --- a/python/samples/semantic-kernel-migration/README.md +++ b/python/samples/semantic-kernel-migration/README.md @@ -70,6 +70,6 @@ Swap the script path for any other workflow or process sample. Deactivate the sa ## Tips for Migration - Keep the original SK sample open while iterating on the AF equivalent; the code is intentionally formatted so you can copy/paste across SDKs. -- Threads/conversation state are explicit in AF. When porting SK code that relies on implicit thread reuse, call `agent.get_new_thread()` and pass it into each `run` call. +- Threads/conversation state are explicit in AF. When porting SK code that relies on implicit thread reuse, call `agent.create_session()` and pass it into each `run` call. - Tools map cleanly: SK `@kernel_function` plugins translate to AF `@tool` callables. Hosted tools (code interpreter, web search, MCP) are available only in AF—introduce them once parity is achieved. - For multi-agent orchestration, AF workflows expose checkpoints and resume capabilities that SK Process/Team abstractions do not. Use the workflow samples as a blueprint when modernizing complex agent graphs. 
From 111959a7f871a501f668e02f6bdeb1ccc86ffa53 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Feb 2026 15:24:49 +0100 Subject: [PATCH 12/28] fix: remove stale message_store, _notify_thread_of_new_messages, and session_id.key references in samples --- .../conversations/redis_chat_message_store_thread.py | 7 ++++++- .../azure_ai_agent/azure_ai_with_existing_thread.py | 3 +-- .../azure_openai/azure_chat_client_with_thread.py | 10 +++++----- .../02-agents/providers/custom/custom_agent.py | 12 ++++++++---- .../openai/openai_chat_client_with_thread.py | 12 +++++------- .../durabletask/03_single_agent_streaming/client.py | 2 +- 6 files changed, 26 insertions(+), 20 deletions(-) diff --git a/python/samples/02-agents/conversations/redis_chat_message_store_thread.py b/python/samples/02-agents/conversations/redis_chat_message_store_thread.py index 5f1a0371f8..f54edd8170 100644 --- a/python/samples/02-agents/conversations/redis_chat_message_store_thread.py +++ b/python/samples/02-agents/conversations/redis_chat_message_store_thread.py @@ -23,6 +23,7 @@ async def example_manual_memory_store() -> None: # Create Redis history provider redis_provider = RedisHistoryProvider( + source_id="redis_basic_chat", redis_url="redis://localhost:6379", ) @@ -60,6 +61,7 @@ async def example_user_session_management() -> None: # Create Redis history provider for specific user session redis_provider = RedisHistoryProvider( + source_id=f"redis_{user_id}", redis_url="redis://localhost:6379", max_messages=10, # Keep only last 10 messages ) @@ -72,7 +74,7 @@ async def example_user_session_management() -> None: ) # Start conversation - session = agent.create_session() + session = agent.create_session(session_id=session_id) print(f"Started session for user {user_id}") @@ -100,6 +102,7 @@ async def example_conversation_persistence() -> None: # Phase 1: Start conversation print("--- Phase 1: Starting conversation ---") redis_provider = RedisHistoryProvider( + 
source_id="redis_persistent_chat", redis_url="redis://localhost:6379", ) @@ -148,6 +151,7 @@ async def example_session_serialization() -> None: print("=== Session Serialization Example ===") redis_provider = RedisHistoryProvider( + source_id="redis_serialization_chat", redis_url="redis://localhost:6379", ) @@ -189,6 +193,7 @@ async def example_message_limits() -> None: # Create provider with small message limit redis_provider = RedisHistoryProvider( + source_id="redis_limited_chat", redis_url="redis://localhost:6379", max_messages=3, # Keep only 3 most recent messages ) diff --git a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_existing_thread.py b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_existing_thread.py index 64b736074a..7a1b15259b 100644 --- a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_existing_thread.py +++ b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_existing_thread.py @@ -51,8 +51,7 @@ async def main() -> None: tools=get_weather, ) - session = agent.create_session(service_session_id=created_thread.id) - assert session.is_initialized + session = agent.get_session(service_session_id=created_thread.id) result = await agent.run("What's the weather like in Tokyo?", session=session) print(f"Result: {result}\n") finally: diff --git a/python/samples/02-agents/providers/azure_openai/azure_chat_client_with_thread.py b/python/samples/02-agents/providers/azure_openai/azure_chat_client_with_thread.py index 1382a14843..04fdea8162 100644 --- a/python/samples/02-agents/providers/azure_openai/azure_chat_client_with_thread.py +++ b/python/samples/02-agents/providers/azure_openai/azure_chat_client_with_thread.py @@ -111,10 +111,11 @@ async def example_with_existing_session_messages() -> None: result1 = await agent.run(query1, session=session) print(f"Agent: {result1.text}") - # The session now contains the conversation history in memory - if session.message_store: - messages = await 
session.message_store.list_messages() - print(f"Session contains {len(messages or [])} messages") + # The session now contains the conversation history in state + memory_state = session.state.get("memory", {}) + messages = memory_state.get("messages", []) + if messages: + print(f"Session contains {len(messages)} messages") print("\n--- Continuing with the same session in a new agent instance ---") @@ -135,7 +136,6 @@ async def example_with_existing_session_messages() -> None: print("\n--- Alternative: Creating a new session from existing messages ---") # You can also create a new session from existing messages - messages = await session.message_store.list_messages() if session.message_store else [] new_session = AgentSession() query3 = "How does the Paris weather compare to London?" diff --git a/python/samples/02-agents/providers/custom/custom_agent.py b/python/samples/02-agents/providers/custom/custom_agent.py index 14626388bd..1d3e5577d4 100644 --- a/python/samples/02-agents/providers/custom/custom_agent.py +++ b/python/samples/02-agents/providers/custom/custom_agent.py @@ -105,9 +105,11 @@ async def _run( response_message = Message(role=Role.ASSISTANT, contents=[Content.from_text(text=echo_text)]) - # Notify the session of new messages if provided + # Store messages in session state if provided if session is not None: - await self._notify_thread_of_new_messages(session, normalized_messages, response_message) + stored = session.state.setdefault("memory", {}).setdefault("messages", []) + stored.extend(normalized_messages) + stored.append(response_message) return AgentResponse(messages=[response_message]) @@ -146,10 +148,12 @@ async def _run_stream( # Small delay to simulate streaming await asyncio.sleep(0.1) - # Notify the session of the complete response if provided + # Store messages in session state if provided if session is not None: complete_response = Message(role=Role.ASSISTANT, contents=[Content.from_text(text=response_text)]) - await 
self._notify_thread_of_new_messages(session, normalized_messages, complete_response) + stored = session.state.setdefault("memory", {}).setdefault("messages", []) + stored.extend(normalized_messages) + stored.append(complete_response) async def main() -> None: diff --git a/python/samples/02-agents/providers/openai/openai_chat_client_with_thread.py b/python/samples/02-agents/providers/openai/openai_chat_client_with_thread.py index ea225d80f6..6470b3a815 100644 --- a/python/samples/02-agents/providers/openai/openai_chat_client_with_thread.py +++ b/python/samples/02-agents/providers/openai/openai_chat_client_with_thread.py @@ -104,10 +104,11 @@ async def example_with_existing_session_messages() -> None: result1 = await agent.run(query1, session=session) print(f"Agent: {result1.text}") - # The session now contains the conversation history in memory - if session.message_store: - messages = await session.message_store.list_messages() - print(f"Session contains {len(messages or [])} messages") + # The session now contains the conversation history in state + memory_state = session.state.get("memory", {}) + messages = memory_state.get("messages", []) + if messages: + print(f"Session contains {len(messages)} messages") print("\n--- Continuing with the same session in a new agent instance ---") @@ -127,9 +128,6 @@ async def example_with_existing_session_messages() -> None: print("\n--- Alternative: Creating a new session from existing messages ---") - # You can also create a new session from existing messages - messages = await session.message_store.list_messages() if session.message_store else [] - new_session = AgentSession() query3 = "How does the Paris weather compare to London?" 
diff --git a/python/samples/04-hosting/durabletask/03_single_agent_streaming/client.py b/python/samples/04-hosting/durabletask/03_single_agent_streaming/client.py index ab0d82ff41..9cb6f4cd88 100644 --- a/python/samples/04-hosting/durabletask/03_single_agent_streaming/client.py +++ b/python/samples/04-hosting/durabletask/03_single_agent_streaming/client.py @@ -146,7 +146,7 @@ def run_client(agent_client: DurableAIAgentClient) -> None: logger.error("Failed to create a new session with session ID!") return - key = session.session_id.key + key = session.session_id logger.info(f"Session ID: {key}") # Get user input From f60b19a4a1d1593288ae137888462a52398f5e4e Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Feb 2026 15:32:25 +0100 Subject: [PATCH 13/28] refactor: merge context_providers and sessions sample folders into sessions, remove aggregate_context_provider --- .../02-agents/context_providers/README.md | 180 ------------ .../aggregate_context_provider.py | 153 ---------- .../02-agents/providers/azure_ai/README.md | 4 +- python/samples/README.md | 1 - .../getting_started/sessions/README.md | 103 +++++++ .../sessions/azure_ai_search/README.md | 264 ++++++++++++++++++ .../azure_ai_with_search_context_agentic.py | 141 ++++++++++ .../azure_ai_with_search_context_semantic.py | 97 +++++++ .../custom_chat_message_store_thread.py | 85 ++++++ .../getting_started/sessions/mem0/README.md | 55 ++++ .../sessions/mem0/mem0_basic.py | 82 ++++++ .../getting_started/sessions/mem0/mem0_oss.py | 79 ++++++ .../sessions/mem0/mem0_threads.py | 167 +++++++++++ .../getting_started/sessions/redis/README.md | 113 ++++++++ .../redis/azure_redis_conversation.py | 123 ++++++++ .../sessions/redis/redis_basics.py | 256 +++++++++++++++++ .../sessions/redis/redis_conversation.py | 105 +++++++ .../sessions/redis/redis_threads.py | 249 +++++++++++++++++ .../redis_chat_message_store_thread.py | 257 +++++++++++++++++ .../sessions/simple_context_provider.py | 129 +++++++++ 
.../sessions/suspend_resume_thread.py | 93 ++++++ 21 files changed, 2400 insertions(+), 336 deletions(-) delete mode 100644 python/samples/02-agents/context_providers/README.md delete mode 100644 python/samples/02-agents/context_providers/aggregate_context_provider.py create mode 100644 python/samples/getting_started/sessions/README.md create mode 100644 python/samples/getting_started/sessions/azure_ai_search/README.md create mode 100644 python/samples/getting_started/sessions/azure_ai_search/azure_ai_with_search_context_agentic.py create mode 100644 python/samples/getting_started/sessions/azure_ai_search/azure_ai_with_search_context_semantic.py create mode 100644 python/samples/getting_started/sessions/custom_chat_message_store_thread.py create mode 100644 python/samples/getting_started/sessions/mem0/README.md create mode 100644 python/samples/getting_started/sessions/mem0/mem0_basic.py create mode 100644 python/samples/getting_started/sessions/mem0/mem0_oss.py create mode 100644 python/samples/getting_started/sessions/mem0/mem0_threads.py create mode 100644 python/samples/getting_started/sessions/redis/README.md create mode 100644 python/samples/getting_started/sessions/redis/azure_redis_conversation.py create mode 100644 python/samples/getting_started/sessions/redis/redis_basics.py create mode 100644 python/samples/getting_started/sessions/redis/redis_conversation.py create mode 100644 python/samples/getting_started/sessions/redis/redis_threads.py create mode 100644 python/samples/getting_started/sessions/redis_chat_message_store_thread.py create mode 100644 python/samples/getting_started/sessions/simple_context_provider.py create mode 100644 python/samples/getting_started/sessions/suspend_resume_thread.py diff --git a/python/samples/02-agents/context_providers/README.md b/python/samples/02-agents/context_providers/README.md deleted file mode 100644 index 442ebce61a..0000000000 --- a/python/samples/02-agents/context_providers/README.md +++ /dev/null @@ -1,180 
+0,0 @@ -# Context Provider Examples - -Context providers enable agents to maintain memory, retrieve relevant information, and enhance conversations with external context. The Agent Framework supports various context providers for different use cases, from simple in-memory storage to advanced persistent solutions with search capabilities. - -This folder contains examples demonstrating how to use different context providers with the Agent Framework. - -## Overview - -Context providers implement two key methods: - -- **`invoking`**: Called before the agent processes a request. Provides additional context, instructions, or retrieved information to enhance the agent's response. -- **`invoked`**: Called after the agent generates a response. Allows for storing information, updating memory, or performing post-processing. - -## Examples - -### Simple Context Provider - -| File | Description | Installation | -|------|-------------|--------------| -| [`simple_context_provider.py`](simple_context_provider.py) | Demonstrates building a custom context provider that extracts and stores user information (name and age) from conversations. Shows how to use structured output to extract data and provide dynamic instructions based on stored context. | No additional package required - uses core `agent-framework` | - -**Install:** -```bash -pip install agent-framework-azure-ai -``` - -### Azure AI Search - -| File | Description | -|------|-------------| -| [`azure_ai_search/azure_ai_with_search_context_agentic.py`](azure_ai_search/azure_ai_with_search_context_agentic.py) | **Agentic mode** (recommended for most scenarios): Uses Knowledge Bases in Azure AI Search for query planning and multi-hop reasoning. Provides more accurate results through intelligent retrieval. Slightly slower with more token consumption. 
| -| [`azure_ai_search/azure_ai_with_search_context_semantic.py`](azure_ai_search/azure_ai_with_search_context_semantic.py) | **Semantic mode** (fast queries): Fast hybrid search combining vector and keyword search with semantic ranking. Best for scenarios where speed is critical. | - -**Install:** -```bash -pip install agent-framework-azure-ai-search agent-framework-azure-ai -``` - -**Prerequisites:** -- Azure AI Search service with a search index -- Azure AI Foundry project with a model deployment -- For agentic mode: Azure OpenAI resource for Knowledge Base model calls -- Environment variables: `AZURE_SEARCH_ENDPOINT`, `AZURE_SEARCH_INDEX_NAME`, `AZURE_AI_PROJECT_ENDPOINT` - -**Key Concepts:** -- **Agentic mode**: Intelligent retrieval with multi-hop reasoning, better for complex queries -- **Semantic mode**: Fast hybrid search with semantic ranking, better for simple queries and speed - -### Mem0 - -The [mem0](mem0/) folder contains examples using Mem0, a self-improving memory layer that enables applications to have long-term memory capabilities. - -| File | Description | -|------|-------------| -| [`mem0/mem0_basic.py`](mem0/mem0_basic.py) | Basic example storing and retrieving user preferences across different conversation threads. | -| [`mem0/mem0_threads.py`](mem0/mem0_threads.py) | Advanced thread scoping strategies: global scope (memories shared), per-operation scope (memories isolated), and multiple agents with different memory configurations. | -| [`mem0/mem0_oss.py`](mem0/mem0_oss.py) | Using Mem0 Open Source self-hosted version as the context provider. 
| - -**Install:** -```bash -pip install agent-framework-mem0 -``` - -**Prerequisites:** -- Mem0 API key from [app.mem0.ai](https://app.mem0.ai/) OR self-host [Mem0 Open Source](https://docs.mem0.ai/open-source/overview) -- For Mem0 Platform: `MEM0_API_KEY` environment variable -- For Mem0 OSS: `OPENAI_API_KEY` for embedding generation - -**Key Concepts:** -- **Global Scope**: Memories shared across all conversation threads -- **Thread Scope**: Memories isolated per conversation thread -- **Memory Association**: Records can be associated with `user_id`, `agent_id`, `thread_id`, or `application_id` - -See the [mem0 README](mem0/README.md) for detailed documentation. - -### Redis - -The [redis](redis/) folder contains examples using Redis (RediSearch) for persistent, searchable memory with full-text and optional hybrid vector search. - -| File | Description | -|------|-------------| -| [`redis/redis_basics.py`](redis/redis_basics.py) | Standalone provider usage and agent integration. Demonstrates writing messages, full-text/hybrid search, persisting preferences, and tool output memory. | -| [`redis/redis_conversation.py`](redis/redis_conversation.py) | Conversational examples showing memory persistence across sessions. | -| [`redis/redis_threads.py`](redis/redis_threads.py) | Thread scoping: global scope, per-operation scope, and multiple agents with isolated memory via different `agent_id` values. 
| - -**Install:** -```bash -pip install agent-framework-redis -``` - -**Prerequisites:** -- Running Redis with RediSearch (Redis Stack or managed service) - - **Docker**: `docker run --name redis -p 6379:6379 -d redis:8.0.3` - - **Redis Cloud**: [redis.io/cloud](https://redis.io/cloud/) - - **Azure Managed Redis**: [Azure quickstart](https://learn.microsoft.com/azure/redis/quickstart-create-managed-redis) -- Optional: `OPENAI_API_KEY` for vector embeddings (hybrid search) - -**Key Concepts:** -- **Full-text search**: Fast keyword-based retrieval -- **Hybrid vector search**: Optional embeddings for semantic search (`vectorizer_choice="openai"` or `"hf"`) -- **Memory scoping**: Partition by `application_id`, `agent_id`, `user_id`, or `thread_id` -- **Thread scoping**: `scope_to_per_operation_thread_id=True` isolates memory per operation - -See the [redis README](redis/README.md) for detailed documentation. - -## Choosing a Context Provider - -| Provider | Use Case | Persistence | Search | Complexity | -|----------|----------|-------------|--------|------------| -| **Simple/Custom** | Learning, prototyping, simple memory needs | No (in-memory) | No | Low | -| **Azure AI Search** | RAG, document search, enterprise knowledge bases | Yes | Hybrid + Semantic | Medium | -| **Mem0** | Long-term user memory, preferences, personalization | Yes (cloud/self-hosted) | Semantic | Low-Medium | -| **Redis** | Fast retrieval, session memory, full-text + vector search | Yes | Full-text + Hybrid | Medium | - -## Common Patterns - -### 1. User Preference Memory -Store and retrieve user preferences, settings, or personal information across sessions. -- **Examples**: `simple_context_provider.py`, `mem0/mem0_basic.py`, `redis/redis_basics.py` - -### 2. Document Retrieval (RAG) -Retrieve relevant documents or knowledge base articles to answer questions. -- **Examples**: `azure_ai_search/azure_ai_with_search_context_*.py` - -### 3. 
Conversation History -Maintain conversation context across multiple turns and sessions. -- **Examples**: `redis/redis_conversation.py`, `mem0/mem0_threads.py` - -### 4. Thread Scoping -Isolate memory per conversation thread or share globally across threads. -- **Examples**: `mem0/mem0_threads.py`, `redis/redis_threads.py` - -### 5. Multi-Agent Memory -Different agents with isolated or shared memory configurations. -- **Examples**: `mem0/mem0_threads.py`, `redis/redis_threads.py` - -## Building Custom Context Providers - -To create a custom context provider, extend `BaseContextProvider`: - -```python -from agent_framework import AgentSession, BaseContextProvider, SessionContext, Message -from typing import Any - -class MyContextProvider(BaseContextProvider): - def __init__(self): - super().__init__("my-context") - - async def before_run( - self, - *, - agent: Any, - session: AgentSession | None, - context: SessionContext, - state: dict[str, Any], - ) -> None: - """Provide context before the agent processes the request.""" - context.extend_messages(self.source_id, [Message("system", ["Additional instructions here"])]) - - async def after_run( - self, - *, - agent: Any, - session: AgentSession | None, - context: SessionContext, - state: dict[str, Any], - ) -> None: - """Process the response after the agent generates it.""" - # Store information, update memory, etc. - pass -``` - -See `simple_context_provider.py` for a complete example. 
- -## Additional Resources - -- [Agent Framework Documentation](https://github.com/microsoft/agent-framework) -- [Azure AI Search Documentation](https://learn.microsoft.com/azure/search/) -- [Mem0 Documentation](https://docs.mem0.ai/) -- [Redis Documentation](https://redis.io/docs/) diff --git a/python/samples/02-agents/context_providers/aggregate_context_provider.py b/python/samples/02-agents/context_providers/aggregate_context_provider.py deleted file mode 100644 index 3f8e346b34..0000000000 --- a/python/samples/02-agents/context_providers/aggregate_context_provider.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -""" -This sample demonstrates how to use multiple context providers with an agent. - -Context providers can be passed as a list to the agent's context_providers parameter. -Each provider is called in order during the agent's lifecycle, and their context -is combined automatically. - -You can use built-in providers or implement your own by extending BaseContextProvider. -""" - -import asyncio -from typing import Any - -from agent_framework import Agent, AgentSession, BaseContextProvider, SessionContext -from agent_framework.azure import AzureAIClient -from azure.identity.aio import AzureCliCredential - -# region Example Context Providers - - -class TimeContextProvider(BaseContextProvider): - """A simple context provider that adds time-related instructions.""" - - def __init__(self): - super().__init__("time") - - async def before_run( - self, - *, - agent: Any, - session: AgentSession | None, - context: SessionContext, - state: dict[str, Any], - ) -> None: - from datetime import datetime - - current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - context.extend_instructions(self.source_id, f"The current date and time is: {current_time}. 
") - - -class PersonaContextProvider(BaseContextProvider): - """A context provider that adds a persona to the agent.""" - - def __init__(self, persona: str): - super().__init__("persona") - self.persona = persona - - async def before_run( - self, - *, - agent: Any, - session: AgentSession | None, - context: SessionContext, - state: dict[str, Any], - ) -> None: - context.extend_instructions(self.source_id, f"Your persona: {self.persona}. ") - - -class PreferencesContextProvider(BaseContextProvider): - """A context provider that adds user preferences.""" - - def __init__(self): - super().__init__("preferences") - self.preferences: dict[str, str] = {} - - async def before_run( - self, - *, - agent: Any, - session: AgentSession | None, - context: SessionContext, - state: dict[str, Any], - ) -> None: - if not self.preferences: - return - prefs_str = ", ".join(f"{k}: {v}" for k, v in self.preferences.items()) - context.extend_instructions(self.source_id, f"User preferences: {prefs_str}. ") - - async def after_run( - self, - *, - agent: Any, - session: AgentSession | None, - context: SessionContext, - state: dict[str, Any], - ) -> None: - # Simple example: extract and store preferences from user messages - # In a real implementation, you might use structured extraction - request_messages = context.get_messages() - - for msg in request_messages: - content = msg.text if hasattr(msg, "text") else "" - # Very simple extraction - in production, use LLM-based extraction - if isinstance(content, str) and "prefer" in content.lower() and ":" in content: - parts = content.split(":") - if len(parts) >= 2: - key = parts[0].strip().lower().replace("i prefer ", "") - value = parts[1].strip() - self.preferences[key] = value - - -# endregion - - -# region Main - - -async def main(): - """Demonstrate using multiple context providers with an agent.""" - async with AzureCliCredential() as credential: - client = AzureAIClient(credential=credential) - - # Create individual context providers - 
time_provider = TimeContextProvider() - persona_provider = PersonaContextProvider("You are a helpful and friendly AI assistant named Max.") - preferences_provider = PreferencesContextProvider() - - # Create the agent with multiple context providers - async with Agent( - client=client, - instructions="You are a helpful assistant.", - context_providers=[ - time_provider, - persona_provider, - preferences_provider, - ], - ) as agent: - # Create a new session for the conversation - session = agent.create_session() - - # First message - the agent should include time and persona context - print("User: Hello! Who are you?") - result = await agent.run("Hello! Who are you?", session=session) - print(f"Agent: {result}\n") - - # Set a preference - print("User: I prefer language: formal English") - result = await agent.run("I prefer language: formal English", session=session) - print(f"Agent: {result}\n") - - # Ask something - the agent should now include the preference - print("User: Can you tell me a fun fact?") - result = await agent.run("Can you tell me a fun fact?", session=session) - print(f"Agent: {result}\n") - - # Show what the aggregate provider is tracking - print(f"\nPreferences tracked: {preferences_provider.preferences}") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/samples/02-agents/providers/azure_ai/README.md b/python/samples/02-agents/providers/azure_ai/README.md index 70e08b4fad..f7dfb0f8ce 100644 --- a/python/samples/02-agents/providers/azure_ai/README.md +++ b/python/samples/02-agents/providers/azure_ai/README.md @@ -28,8 +28,8 @@ This folder contains examples demonstrating different ways to create and use age | [`azure_ai_with_local_mcp.py`](azure_ai_with_local_mcp.py) | Shows how to integrate local Model Context Protocol (MCP) tools with Azure AI agents. 
| | [`azure_ai_with_response_format.py`](azure_ai_with_response_format.py) | Shows how to use structured outputs (response format) with Azure AI agents using Pydantic models to enforce specific response schemas. | | [`azure_ai_with_runtime_json_schema.py`](azure_ai_with_runtime_json_schema.py) | Shows how to use structured outputs (response format) with Azure AI agents using a JSON schema to enforce specific response schemas. | -| [`azure_ai_with_search_context_agentic.py`](../../context_providers/azure_ai_search/azure_ai_with_search_context_agentic.py) | Shows how to use AzureAISearchContextProvider with agentic mode. Uses Knowledge Bases for multi-hop reasoning across documents with query planning. Recommended for most scenarios - slightly slower with more token consumption for query planning, but more accurate results. | -| [`azure_ai_with_search_context_semantic.py`](../../context_providers/azure_ai_search/azure_ai_with_search_context_semantic.py) | Shows how to use AzureAISearchContextProvider with semantic mode. Fast hybrid search with vector + keyword search and semantic ranking for RAG. Best for simple queries where speed is critical. | +| [`azure_ai_with_search_context_agentic.py`](../../../getting_started/sessions/azure_ai_search/azure_ai_with_search_context_agentic.py) | Shows how to use AzureAISearchContextProvider with agentic mode. Uses Knowledge Bases for multi-hop reasoning across documents with query planning. Recommended for most scenarios - slightly slower with more token consumption for query planning, but more accurate results. | +| [`azure_ai_with_search_context_semantic.py`](../../../getting_started/sessions/azure_ai_search/azure_ai_with_search_context_semantic.py) | Shows how to use AzureAISearchContextProvider with semantic mode. Fast hybrid search with vector + keyword search and semantic ranking for RAG. Best for simple queries where speed is critical. 
| | [`azure_ai_with_sharepoint.py`](azure_ai_with_sharepoint.py) | Shows how to use SharePoint grounding with Azure AI agents to search through SharePoint content and answer user questions with proper citations. Requires a SharePoint connection configured in your Azure AI project. | | [`azure_ai_with_thread.py`](azure_ai_with_thread.py) | Demonstrates thread management with Azure AI agents, including automatic thread creation for stateless conversations and explicit thread management for maintaining conversation context across multiple interactions. | | [`azure_ai_with_image_generation.py`](azure_ai_with_image_generation.py) | Shows how to use `AzureAIClient.get_image_generation_tool()` with Azure AI agents to generate images based on text prompts. | diff --git a/python/samples/README.md b/python/samples/README.md index 328ba9aa2e..36df843e6a 100644 --- a/python/samples/README.md +++ b/python/samples/README.md @@ -43,4 +43,3 @@ For Azure authentication, run `az login` before running samples. - [Agent Framework Documentation](https://learn.microsoft.com/agent-framework/) - [AGENTS.md](./AGENTS.md) — Structure documentation for maintainers - [SAMPLE_GUIDELINES.md](./SAMPLE_GUIDELINES.md) — Coding conventions for samples - diff --git a/python/samples/getting_started/sessions/README.md b/python/samples/getting_started/sessions/README.md new file mode 100644 index 0000000000..6910aab677 --- /dev/null +++ b/python/samples/getting_started/sessions/README.md @@ -0,0 +1,103 @@ +# Sessions & Context Provider Examples + +Sessions and context providers are the core building blocks for agent memory in the Agent Framework. Sessions hold conversation state across turns, while context providers add, retrieve, and persist context before and after each agent invocation. + +## Core Concepts + +- **`AgentSession`**: Lightweight state container holding a `session_id` and a mutable `state` dict. Pass to `agent.run()` to maintain conversation across turns. 
+- **`BaseContextProvider`**: Hook that runs `before_run` / `after_run` around each invocation. Use for injecting instructions, RAG context, or metadata. +- **`BaseHistoryProvider`**: Subclass of `BaseContextProvider` for conversation history storage. Implements `get_messages()` / `save_messages()` and handles load/store automatically. +- **`InMemoryHistoryProvider`**: Built-in provider storing messages in `session.state`. Auto-injected when no providers are configured. + +## Examples + +### Session Management + +| File | Description | +|------|-------------| +| [`suspend_resume_thread.py`](suspend_resume_thread.py) | Suspend and resume sessions via `to_dict()` / `from_dict()` — both service-managed (Azure AI) and in-memory (OpenAI). | +| [`custom_chat_message_store_thread.py`](custom_chat_message_store_thread.py) | Implement a custom `BaseHistoryProvider` with dict-based storage. Shows serialization/deserialization. | +| [`redis_chat_message_store_thread.py`](redis_chat_message_store_thread.py) | `RedisHistoryProvider` for persistent storage: basic usage, user sessions, persistence across restarts, serialization, and message trimming. | + +### Custom Context Providers + +| File | Description | +|------|-------------| +| [`simple_context_provider.py`](simple_context_provider.py) | Build a custom `BaseContextProvider` that extracts and stores user information using structured output, then provides dynamic instructions based on stored context. | + +### Azure AI Search + +| File | Description | +|------|-------------| +| [`azure_ai_search/azure_ai_with_search_context_agentic.py`](azure_ai_search/azure_ai_with_search_context_agentic.py) | **Agentic mode** — Knowledge Bases with query planning and multi-hop reasoning. | +| [`azure_ai_search/azure_ai_with_search_context_semantic.py`](azure_ai_search/azure_ai_with_search_context_semantic.py) | **Semantic mode** — fast hybrid search with semantic ranking. 
| + +### Mem0 + +| File | Description | +|------|-------------| +| [`mem0/mem0_basic.py`](mem0/mem0_basic.py) | Basic Mem0 integration for user preference memory. | +| [`mem0/mem0_threads.py`](mem0/mem0_threads.py) | Thread scoping: global scope, per-operation scope, and multi-agent isolation. | +| [`mem0/mem0_oss.py`](mem0/mem0_oss.py) | Mem0 Open Source (self-hosted) integration. | + +### Redis + +| File | Description | +|------|-------------| +| [`redis/redis_basics.py`](redis/redis_basics.py) | Standalone provider usage, full-text/hybrid search, preferences, and tool output memory. | +| [`redis/redis_conversation.py`](redis/redis_conversation.py) | Conversation persistence across sessions. | +| [`redis/redis_threads.py`](redis/redis_threads.py) | Thread scoping: global, per-operation, and multi-agent isolation. | +| [`redis/azure_redis_conversation.py`](redis/azure_redis_conversation.py) | Azure Managed Redis with Entra ID authentication. | + +## Choosing a Provider + +| Provider | Use Case | Persistence | Search | +|----------|----------|-------------|--------| +| **InMemoryHistoryProvider** | Prototyping, stateless apps | Session state only | No | +| **Custom BaseHistoryProvider** | File/DB-backed storage | Your choice | Your choice | +| **RedisHistoryProvider** | Fast persistent chat history | Yes (Redis) | No | +| **RedisContextProvider** | Searchable memory / RAG | Yes (Redis) | Full-text + Hybrid | +| **Mem0ContextProvider** | Long-term user memory | Yes (cloud/self-hosted) | Semantic | +| **AzureAISearchContextProvider** | Enterprise RAG | Yes (Azure) | Hybrid + Semantic | + +## Building Custom Providers + +### Custom Context Provider + +```python +from agent_framework import AgentSession, BaseContextProvider, SessionContext, Message +from typing import Any + +class MyContextProvider(BaseContextProvider): + def __init__(self): + super().__init__("my-context") + + async def before_run(self, *, agent: Any, session: AgentSession | None, + context: 
SessionContext, state: dict[str, Any]) -> None: + context.extend_messages(self.source_id, [Message("system", ["Extra context here"])]) + + async def after_run(self, *, agent: Any, session: AgentSession | None, + context: SessionContext, state: dict[str, Any]) -> None: + pass # Store information, update memory, etc. +``` + +### Custom History Provider + +```python +from agent_framework import BaseHistoryProvider, Message +from collections.abc import Sequence +from typing import Any + +class MyHistoryProvider(BaseHistoryProvider): + def __init__(self): + super().__init__("my-history") + + async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Message]: + ... # Load from your storage + + async def save_messages(self, session_id: str | None, + messages: Sequence[Message], **kwargs: Any) -> None: + ... # Persist to your storage +``` + +See `custom_chat_message_store_thread.py` and `simple_context_provider.py` for complete examples. diff --git a/python/samples/getting_started/sessions/azure_ai_search/README.md b/python/samples/getting_started/sessions/azure_ai_search/README.md new file mode 100644 index 0000000000..49403d106c --- /dev/null +++ b/python/samples/getting_started/sessions/azure_ai_search/README.md @@ -0,0 +1,264 @@ +# Azure AI Search Context Provider Examples + +Azure AI Search context provider enables Retrieval Augmented Generation (RAG) with your agents by retrieving relevant documents from Azure AI Search indexes. It supports two search modes optimized for different use cases. + +This folder contains examples demonstrating how to use the Azure AI Search context provider with the Agent Framework. + +## Examples + +| File | Description | +|------|-------------| +| [`azure_ai_with_search_context_agentic.py`](azure_ai_with_search_context_agentic.py) | **Agentic mode** (recommended for most scenarios): Uses Knowledge Bases in Azure AI Search for query planning and multi-hop reasoning. 
Provides more accurate results through intelligent retrieval with automatic query reformulation. Slightly slower with more token consumption for query planning. [Learn more](https://techcommunity.microsoft.com/blog/azure-ai-foundry-blog/foundry-iq-boost-response-relevance-by-36-with-agentic-retrieval/4470720) | +| [`azure_ai_with_search_context_semantic.py`](azure_ai_with_search_context_semantic.py) | **Semantic mode** (fast queries): Fast hybrid search combining vector and keyword search with semantic ranking. Returns raw search results as context. Best for scenarios where speed is critical and simple retrieval is sufficient. | + +## Installation + +```bash +pip install agent-framework-azure-ai-search agent-framework-azure-ai +``` + +## Prerequisites + +### Required Resources + +1. **Azure AI Search service** with a search index containing your documents + - [Create Azure AI Search service](https://learn.microsoft.com/azure/search/search-create-service-portal) + - [Create and populate a search index](https://learn.microsoft.com/azure/search/search-what-is-an-index) + +2. **Azure AI Foundry project** with a model deployment + - [Create Azure AI Foundry project](https://learn.microsoft.com/azure/ai-studio/how-to/create-projects) + - Deploy a model (e.g., GPT-4o) + +3. **For Agentic mode only**: Azure OpenAI resource for Knowledge Base model calls + - [Create Azure OpenAI resource](https://learn.microsoft.com/azure/ai-services/openai/how-to/create-resource) + - Note: This is separate from your Azure AI Foundry project endpoint + +### Authentication + +Both examples support two authentication methods: + +- **API Key**: Set `AZURE_SEARCH_API_KEY` environment variable +- **Entra ID (Managed Identity)**: Uses `DefaultAzureCredential` when API key is not provided + +Run `az login` if using Entra ID authentication. 
+ +## Configuration + +### Environment Variables + +**Common (both modes):** +- `AZURE_SEARCH_ENDPOINT`: Your Azure AI Search endpoint (e.g., `https://myservice.search.windows.net`) +- `AZURE_SEARCH_INDEX_NAME`: Name of your search index +- `AZURE_AI_PROJECT_ENDPOINT`: Your Azure AI Foundry project endpoint +- `AZURE_AI_MODEL_DEPLOYMENT_NAME`: Model deployment name (e.g., `gpt-4o`, defaults to `gpt-4o`) +- `AZURE_SEARCH_API_KEY`: _(Optional)_ Your search API key - if not provided, uses DefaultAzureCredential + +**Agentic mode only:** +- `AZURE_SEARCH_KNOWLEDGE_BASE_NAME`: Name of your Knowledge Base in Azure AI Search +- `AZURE_OPENAI_RESOURCE_URL`: Your Azure OpenAI resource URL (e.g., `https://myresource.openai.azure.com`) + - **Important**: This is different from `AZURE_AI_PROJECT_ENDPOINT` - Knowledge Base needs the OpenAI endpoint for model calls + +### Example .env file + +**For Semantic Mode:** +```env +AZURE_SEARCH_ENDPOINT=https://myservice.search.windows.net +AZURE_SEARCH_INDEX_NAME=my-index +AZURE_AI_PROJECT_ENDPOINT=https://.services.ai.azure.com/api/projects/ +AZURE_AI_MODEL_DEPLOYMENT_NAME=gpt-4o +# Optional - omit to use Entra ID +AZURE_SEARCH_API_KEY=your-search-key +``` + +**For Agentic Mode (add these to semantic mode variables):** +```env +AZURE_SEARCH_KNOWLEDGE_BASE_NAME=my-knowledge-base +AZURE_OPENAI_RESOURCE_URL=https://myresource.openai.azure.com +``` + +## Search Modes Comparison + +| Feature | Semantic Mode | Agentic Mode | +|---------|--------------|--------------| +| **Speed** | Fast | Slower (query planning overhead) | +| **Token Usage** | Lower | Higher (query reformulation) | +| **Retrieval Strategy** | Hybrid search + semantic ranking | Multi-hop reasoning with Knowledge Base | +| **Query Handling** | Direct search | Automatic query reformulation | +| **Best For** | Simple queries, speed-critical apps | Complex queries, multi-document reasoning | +| **Additional Setup** | None | Requires Knowledge Base + OpenAI resource | + +### When 
to Use Semantic Mode + +- **Simple queries** where direct keyword/vector search is sufficient +- **Speed is critical** and you need low latency +- **Straightforward retrieval** from single documents +- **Lower token costs** are important + +### When to Use Agentic Mode + +- **Complex queries** requiring multi-hop reasoning +- **Cross-document analysis** where information spans multiple sources +- **Ambiguous queries** that benefit from automatic reformulation +- **Higher accuracy** is more important than speed +- You need **intelligent query planning** and document synthesis + +## How the Examples Work + +### Semantic Mode Flow + +1. User query is sent to Azure AI Search +2. Hybrid search (vector + keyword) retrieves relevant documents +3. Semantic ranking reorders results for relevance +4. Top-k documents are returned as context +5. Agent generates response using retrieved context + +### Agentic Mode Flow + +1. User query is sent to the Knowledge Base +2. Knowledge Base plans the retrieval strategy +3. Multiple search queries may be executed (multi-hop) +4. Retrieved information is synthesized +5. Enhanced context is provided to the agent +6. 
Agent generates response with comprehensive context + +## Code Example + +### Semantic Mode + +```python +from agent_framework import Agent +from agent_framework.azure import AzureAIAgentClient, AzureAISearchContextProvider +from azure.identity.aio import DefaultAzureCredential + +# Create search provider with semantic mode (default) +search_provider = AzureAISearchContextProvider( + endpoint=search_endpoint, + index_name=index_name, + api_key=search_key, # Or use credential for Entra ID + mode="semantic", # Default mode + top_k=3, # Number of documents to retrieve +) + +# Create agent with search context +async with AzureAIAgentClient(credential=DefaultAzureCredential()) as client: + async with Agent( + client=client, + model=model_deployment, + context_providers=[search_provider], + ) as agent: + response = await agent.run("What information is in the knowledge base?") +``` + +### Agentic Mode + +```python +from agent_framework.azure import AzureAISearchContextProvider + +# Create search provider with agentic mode +search_provider = AzureAISearchContextProvider( + endpoint=search_endpoint, + index_name=index_name, + api_key=search_key, + mode="agentic", # Enable agentic retrieval + knowledge_base_name=knowledge_base_name, + azure_openai_resource_url=azure_openai_resource_url, + top_k=5, +) + +# Use with agent (same as semantic mode) +async with Agent( + client=client, + model=model_deployment, + context_providers=[search_provider], +) as agent: + response = await agent.run("Analyze and compare topics across documents") +``` + +## Running the Examples + +1. **Set up environment variables** (see Configuration section above) + +2. **Ensure you have an Azure AI Search index** with documents: + ```bash + # Verify your index exists + curl -X GET "https://myservice.search.windows.net/indexes/my-index?api-version=2024-07-01" \ + -H "api-key: YOUR_API_KEY" + ``` + +3. 
**For agentic mode**: Create a Knowledge Base in Azure AI Search + - [Knowledge Base documentation](https://learn.microsoft.com/azure/search/knowledge-store-create-portal) + +4. **Run the examples**: + ```bash + # Semantic mode (fast, simple) + python azure_ai_with_search_context_semantic.py + + # Agentic mode (intelligent, complex) + python azure_ai_with_search_context_agentic.py + ``` + +## Key Parameters + +### Common Parameters + +- `endpoint`: Azure AI Search service endpoint +- `index_name`: Name of the search index +- `api_key`: API key for authentication (optional, can use credential instead) +- `credential`: Azure credential for Entra ID auth (e.g., `DefaultAzureCredential()`) +- `mode`: Search mode - `"semantic"` (default) or `"agentic"` +- `top_k`: Number of documents to retrieve (default: 3 for semantic, 5 for agentic) + +### Semantic Mode Parameters + +- `semantic_configuration`: Name of semantic configuration in your index (optional) +- `query_type`: Query type - `"semantic"` for semantic search (default) + +### Agentic Mode Parameters + +- `knowledge_base_name`: Name of your Knowledge Base (required) +- `azure_openai_resource_url`: Azure OpenAI resource URL (required) +- `max_search_queries`: Maximum number of search queries to generate (default: 3) + +## Troubleshooting + +### Common Issues + +1. **Authentication errors** + - Ensure `AZURE_SEARCH_API_KEY` is set, or run `az login` for Entra ID auth + - Verify your credentials have search permissions + +2. **Index not found** + - Verify `AZURE_SEARCH_INDEX_NAME` matches your index name exactly + - Check that the index exists and contains documents + +3. **Agentic mode errors** + - Ensure `AZURE_SEARCH_KNOWLEDGE_BASE_NAME` is correctly configured + - Verify `AZURE_OPENAI_RESOURCE_URL` points to your Azure OpenAI resource (not AI Foundry endpoint) + - Check that your OpenAI resource has the necessary model deployments + +4. 
**No results returned** + - Verify your index has documents with vector embeddings (for semantic/hybrid search) + - Check that your queries match the content in your index + - Try increasing `top_k` parameter + +5. **Slow responses in agentic mode** + - This is expected - agentic mode trades speed for accuracy + - Reduce `max_search_queries` if needed + - Consider semantic mode for speed-critical applications + +## Performance Tips + +- **Use semantic mode** as the default for most scenarios - it's fast and effective +- **Switch to agentic mode** when you need multi-hop reasoning or complex queries +- **Adjust `top_k`** based on your needs - higher values provide more context but increase token usage +- **Enable semantic configuration** in your index for better semantic ranking +- **Use Entra ID authentication** in production for better security + +## Additional Resources + +- [Azure AI Search Documentation](https://learn.microsoft.com/azure/search/) +- [Azure AI Foundry Documentation](https://learn.microsoft.com/azure/ai-studio/) +- [RAG with Azure AI Search](https://learn.microsoft.com/azure/search/retrieval-augmented-generation-overview) +- [Semantic Search in Azure AI Search](https://learn.microsoft.com/azure/search/semantic-search-overview) +- [Knowledge Bases in Azure AI Search](https://learn.microsoft.com/azure/search/knowledge-store-concept-intro) +- [Agentic Retrieval Blog Post](https://techcommunity.microsoft.com/blog/azure-ai-foundry-blog/foundry-iq-boost-response-relevance-by-36-with-agentic-retrieval/4470720) diff --git a/python/samples/getting_started/sessions/azure_ai_search/azure_ai_with_search_context_agentic.py b/python/samples/getting_started/sessions/azure_ai_search/azure_ai_with_search_context_agentic.py new file mode 100644 index 0000000000..5a4503f920 --- /dev/null +++ b/python/samples/getting_started/sessions/azure_ai_search/azure_ai_with_search_context_agentic.py @@ -0,0 +1,141 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +import asyncio +import os + +from agent_framework import Agent +from agent_framework.azure import AzureAIAgentClient, AzureAISearchContextProvider +from azure.identity.aio import AzureCliCredential +from dotenv import load_dotenv + +# Load environment variables from .env file +load_dotenv() + +""" +This sample demonstrates how to use Azure AI Search with agentic mode for RAG +(Retrieval Augmented Generation) with Azure AI agents. + +**Agentic mode** is recommended for most scenarios: +- Uses Knowledge Bases in Azure AI Search for query planning +- Performs multi-hop reasoning across documents +- Provides more accurate results through intelligent retrieval +- Slightly slower with more token consumption for query planning +- See: https://techcommunity.microsoft.com/blog/azure-ai-foundry-blog/foundry-iq-boost-response-relevance-by-36-with-agentic-retrieval/4470720 + +For simple queries where speed is critical, use semantic mode instead (see azure_ai_with_search_context_semantic.py). + +Prerequisites: +1. An Azure AI Search service +2. An Azure AI Foundry project with a model deployment +3. 
Either an existing Knowledge Base OR a search index (to auto-create a KB) + +Environment variables: + - AZURE_SEARCH_ENDPOINT: Your Azure AI Search endpoint + - AZURE_SEARCH_API_KEY: (Optional) API key - if not provided, uses DefaultAzureCredential + - AZURE_AI_PROJECT_ENDPOINT: Your Azure AI Foundry project endpoint + - AZURE_AI_MODEL_DEPLOYMENT_NAME: Your model deployment name (e.g., "gpt-4o") + +For using an existing Knowledge Base (recommended): + - AZURE_SEARCH_KNOWLEDGE_BASE_NAME: Your Knowledge Base name + +For auto-creating a Knowledge Base from an index: + - AZURE_SEARCH_INDEX_NAME: Your search index name + - AZURE_OPENAI_RESOURCE_URL: Azure OpenAI resource URL (e.g., "https://myresource.openai.azure.com") +""" + +# Sample queries to demonstrate agentic RAG +USER_INPUTS = [ + "What information is available in the knowledge base?", + "Analyze and compare the main topics from different documents", + "What connections can you find across different sections?", +] + + +async def main() -> None: + """Main function demonstrating Azure AI Search agentic mode.""" + + # Get configuration from environment + search_endpoint = os.environ["AZURE_SEARCH_ENDPOINT"] + search_key = os.environ.get("AZURE_SEARCH_API_KEY") + project_endpoint = os.environ["AZURE_AI_PROJECT_ENDPOINT"] + model_deployment = os.environ.get("AZURE_AI_MODEL_DEPLOYMENT_NAME", "gpt-4o") + + # Agentic mode requires exactly ONE of: knowledge_base_name OR index_name + # Option 1: Use existing Knowledge Base (recommended) + knowledge_base_name = os.environ.get("AZURE_SEARCH_KNOWLEDGE_BASE_NAME") + # Option 2: Auto-create KB from index (requires azure_openai_resource_url) + index_name = os.environ.get("AZURE_SEARCH_INDEX_NAME") + azure_openai_resource_url = os.environ.get("AZURE_OPENAI_RESOURCE_URL") + + # Create Azure AI Search context provider with agentic mode (recommended for accuracy) + print("Using AGENTIC mode (Knowledge Bases with query planning, recommended)\n") + print("This mode is slightly 
slower but provides more accurate results.\n") + + # Configure based on whether using existing KB or auto-creating from index + if knowledge_base_name: + # Use existing Knowledge Base - simplest approach + search_provider = AzureAISearchContextProvider( + endpoint=search_endpoint, + api_key=search_key, + credential=AzureCliCredential() if not search_key else None, + mode="agentic", + knowledge_base_name=knowledge_base_name, + # Optional: Configure retrieval behavior + knowledge_base_output_mode="extractive_data", # or "answer_synthesis" + retrieval_reasoning_effort="minimal", # or "medium", "low" + ) + else: + # Auto-create Knowledge Base from index + if not index_name: + raise ValueError("Set AZURE_SEARCH_KNOWLEDGE_BASE_NAME or AZURE_SEARCH_INDEX_NAME") + if not azure_openai_resource_url: + raise ValueError("AZURE_OPENAI_RESOURCE_URL required when using index_name") + search_provider = AzureAISearchContextProvider( + endpoint=search_endpoint, + index_name=index_name, + api_key=search_key, + credential=AzureCliCredential() if not search_key else None, + mode="agentic", + azure_openai_resource_url=azure_openai_resource_url, + model_deployment_name=model_deployment, + # Optional: Configure retrieval behavior + knowledge_base_output_mode="extractive_data", # or "answer_synthesis" + retrieval_reasoning_effort="minimal", # or "medium", "low" + top_k=3, + ) + + # Create agent with search context provider + async with ( + search_provider, + AzureAIAgentClient( + project_endpoint=project_endpoint, + model_deployment_name=model_deployment, + credential=AzureCliCredential(), + ) as client, + Agent( + client=client, + name="SearchAgent", + instructions=( + "You are a helpful assistant with advanced reasoning capabilities. " + "Use the provided context from the knowledge base to answer complex " + "questions that may require synthesizing information from multiple sources." 
+ ), + context_providers=[search_provider], + ) as agent, + ): + print("=== Azure AI Agent with Search Context (Agentic Mode) ===\n") + + for user_input in USER_INPUTS: + print(f"User: {user_input}") + print("Agent: ", end="", flush=True) + + # Stream response + async for chunk in agent.run(user_input, stream=True): + if chunk.text: + print(chunk.text, end="", flush=True) + + print("\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/sessions/azure_ai_search/azure_ai_with_search_context_semantic.py b/python/samples/getting_started/sessions/azure_ai_search/azure_ai_with_search_context_semantic.py new file mode 100644 index 0000000000..8309d5197c --- /dev/null +++ b/python/samples/getting_started/sessions/azure_ai_search/azure_ai_with_search_context_semantic.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import os + +from agent_framework import Agent +from agent_framework.azure import AzureAIAgentClient, AzureAISearchContextProvider +from azure.identity.aio import AzureCliCredential +from dotenv import load_dotenv + +# Load environment variables from .env file +load_dotenv() + +""" +This sample demonstrates how to use Azure AI Search with semantic mode for RAG +(Retrieval Augmented Generation) with Azure AI agents. + +**Semantic mode** is the recommended default mode: +- Fast hybrid search combining vector and keyword search +- Uses semantic ranking for improved relevance +- Returns raw search results as context +- Best for most RAG use cases + +Prerequisites: +1. An Azure AI Search service with a search index +2. An Azure AI Foundry project with a model deployment +3. 
Set the following environment variables: + - AZURE_SEARCH_ENDPOINT: Your Azure AI Search endpoint + - AZURE_SEARCH_API_KEY: (Optional) Your search API key - if not provided, uses DefaultAzureCredential for Entra ID + - AZURE_SEARCH_INDEX_NAME: Your search index name + - AZURE_AI_PROJECT_ENDPOINT: Your Azure AI Foundry project endpoint + - AZURE_AI_MODEL_DEPLOYMENT_NAME: Your model deployment name (e.g., "gpt-4o") +""" + +# Sample queries to demonstrate RAG +USER_INPUTS = [ + "What information is available in the knowledge base?", + "Summarize the main topics from the documents", + "Find specific details about the content", +] + + +async def main() -> None: + """Main function demonstrating Azure AI Search semantic mode.""" + + # Get configuration from environment + search_endpoint = os.environ["AZURE_SEARCH_ENDPOINT"] + search_key = os.environ.get("AZURE_SEARCH_API_KEY") + index_name = os.environ["AZURE_SEARCH_INDEX_NAME"] + project_endpoint = os.environ["AZURE_AI_PROJECT_ENDPOINT"] + model_deployment = os.environ.get("AZURE_AI_MODEL_DEPLOYMENT_NAME", "gpt-4o") + + # Create Azure AI Search context provider with semantic mode (recommended, fast) + print("Using SEMANTIC mode (hybrid search + semantic ranking, fast)\n") + search_provider = AzureAISearchContextProvider( + endpoint=search_endpoint, + index_name=index_name, + api_key=search_key, # Use api_key for API key auth, or credential for managed identity + credential=AzureCliCredential() if not search_key else None, + mode="semantic", # Default mode + top_k=3, # Retrieve top 3 most relevant documents + ) + + # Create agent with search context provider + async with ( + search_provider, + AzureAIAgentClient( + project_endpoint=project_endpoint, + model_deployment_name=model_deployment, + credential=AzureCliCredential(), + ) as client, + Agent( + client=client, + name="SearchAgent", + instructions=( + "You are a helpful assistant. Use the provided context from the " + "knowledge base to answer questions accurately." 
+ ), + context_providers=[search_provider], + ) as agent, + ): + print("=== Azure AI Agent with Search Context (Semantic Mode) ===\n") + + for user_input in USER_INPUTS: + print(f"User: {user_input}") + print("Agent: ", end="", flush=True) + + # Stream response + async for chunk in agent.run(user_input, stream=True): + if chunk.text: + print(chunk.text, end="", flush=True) + + print("\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/sessions/custom_chat_message_store_thread.py b/python/samples/getting_started/sessions/custom_chat_message_store_thread.py new file mode 100644 index 0000000000..e3ce5c5905 --- /dev/null +++ b/python/samples/getting_started/sessions/custom_chat_message_store_thread.py @@ -0,0 +1,85 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +from collections.abc import Sequence +from typing import Any + +from agent_framework import AgentSession, BaseHistoryProvider, Message +from agent_framework.openai import OpenAIChatClient + +""" +Custom History Provider Example + +This sample demonstrates how to implement and use a custom history provider +for session management, allowing you to persist conversation history in your +preferred storage solution (database, file system, etc.). +""" + + +class CustomHistoryProvider(BaseHistoryProvider): + """Implementation of custom history provider. 
+ In real applications, this can be an implementation of relational database or vector store.""" + + def __init__(self) -> None: + super().__init__("custom-history") + self._storage: dict[str, list[Message]] = {} + + async def get_messages( + self, session_id: str | None, *, state: dict[str, Any] | None = None, **kwargs: Any + ) -> list[Message]: + key = session_id or "default" + return list(self._storage.get(key, [])) + + async def save_messages( + self, + session_id: str | None, + messages: Sequence[Message], + *, + state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + key = session_id or "default" + if key not in self._storage: + self._storage[key] = [] + self._storage[key].extend(messages) + + +async def main() -> None: + """Demonstrates how to use 3rd party or custom history provider for sessions.""" + print("=== Session with 3rd party or custom history provider ===") + + # OpenAI Chat Client is used as an example here, + # other chat clients can be used as well. + agent = OpenAIChatClient().as_agent( + name="CustomBot", + instructions="You are a helpful assistant that remembers our conversation.", + # Use custom history provider. + # If not provided, the default in-memory provider will be used. + context_providers=[CustomHistoryProvider()], + ) + + # Start a new session for the agent conversation. + session = agent.create_session() + + # Respond to user input. + query = "Hello! My name is Alice and I love pizza." + print(f"User: {query}") + print(f"Agent: {await agent.run(query, session=session)}\n") + + # Serialize the session state, so it can be stored for later use. + serialized_session = session.to_dict() + + # The session can now be saved to a database, file, or any other storage mechanism and loaded again later. + print(f"Serialized session: {serialized_session}\n") + + # Deserialize the session state after loading from storage. + resumed_session = AgentSession.from_dict(serialized_session) + + # Respond to user input. 
+ query = "What do you remember about me?" + print(f"User: {query}") + print(f"Agent: {await agent.run(query, session=resumed_session)}\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/sessions/mem0/README.md b/python/samples/getting_started/sessions/mem0/README.md new file mode 100644 index 0000000000..61d8bbd51f --- /dev/null +++ b/python/samples/getting_started/sessions/mem0/README.md @@ -0,0 +1,55 @@ +# Mem0 Context Provider Examples + +[Mem0](https://mem0.ai/) is a self-improving memory layer for Large Language Models that enables applications to have long-term memory capabilities. The Agent Framework's Mem0 context provider integrates with Mem0's API to provide persistent memory across conversation sessions. + +This folder contains examples demonstrating how to use the Mem0 context provider with the Agent Framework for persistent memory and context management across conversations. + +## Examples + +| File | Description | +|------|-------------| +| [`mem0_basic.py`](mem0_basic.py) | Basic example of using Mem0 context provider to store and retrieve user preferences across different conversation threads. | +| [`mem0_threads.py`](mem0_threads.py) | Advanced example demonstrating different thread scoping strategies with Mem0. Covers global thread scope (memories shared across all operations), per-operation thread scope (memories isolated per thread), and multiple agents with different memory configurations for personal vs. work contexts. | +| [`mem0_oss.py`](mem0_oss.py) | Example of using the Mem0 Open Source self-hosted version as the context provider. Demonstrates setup and configuration for local deployment. | + +## Prerequisites + +### Required Resources + +1. [Mem0 API Key](https://app.mem0.ai/) - Sign up for a Mem0 account and get your API key - _or_ self-host [Mem0 Open Source](https://docs.mem0.ai/open-source/overview) +2. Azure AI project endpoint (used in these examples) +3. 
Azure CLI authentication (run `az login`) + +## Configuration + +### Environment Variables + +Set the following environment variables: + +**For Mem0 Platform:** +- `MEM0_API_KEY`: Your Mem0 API key (alternatively, pass it as `api_key` parameter to `Mem0Provider`). Not required if you are self-hosting [Mem0 Open Source](https://docs.mem0.ai/open-source/overview) + +**For Mem0 Open Source:** +- `OPENAI_API_KEY`: Your OpenAI API key (used by Mem0 OSS for embedding generation and automatic memory extraction) + +**For Azure AI:** +- `AZURE_AI_PROJECT_ENDPOINT`: Your Azure AI project endpoint +- `AZURE_AI_MODEL_DEPLOYMENT_NAME`: The name of your model deployment + +## Key Concepts + +### Memory Scoping + +The Mem0 context provider supports different scoping strategies: + +- **Global Scope** (`scope_to_per_operation_thread_id=False`): Memories are shared across all conversation threads +- **Thread Scope** (`scope_to_per_operation_thread_id=True`): Memories are isolated per conversation thread + +### Memory Association + +Mem0 records can be associated with different identifiers: + +- `user_id`: Associate memories with a specific user +- `agent_id`: Associate memories with a specific agent +- `thread_id`: Associate memories with a specific conversation thread +- `application_id`: Associate memories with an application context diff --git a/python/samples/getting_started/sessions/mem0/mem0_basic.py b/python/samples/getting_started/sessions/mem0/mem0_basic.py new file mode 100644 index 0000000000..f7a3a7f91f --- /dev/null +++ b/python/samples/getting_started/sessions/mem0/mem0_basic.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import uuid + +from agent_framework import tool +from agent_framework.azure import AzureAIAgentClient +from agent_framework.mem0 import Mem0ContextProvider +from azure.identity.aio import AzureCliCredential + + +# NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +@tool(approval_mode="never_require") +def retrieve_company_report(company_code: str, detailed: bool) -> str: + if company_code != "CNTS": + raise ValueError("Company code not found") + if not detailed: + return "CNTS is a company that specializes in technology." + return ( + "CNTS is a company that specializes in technology. " + "It had a revenue of $10 million in 2022. It has 100 employees." + ) + + +async def main() -> None: + """Example of memory usage with Mem0 context provider.""" + print("=== Mem0 Context Provider Example ===") + + # Each record in Mem0 should be associated with agent_id or user_id or application_id or thread_id. + # In this example, we associate Mem0 records with user_id. + user_id = str(uuid.uuid4()) + + # For Azure authentication, run `az login` command in terminal or replace AzureCliCredential with preferred + # authentication option. + # For Mem0 authentication, set Mem0 API key via "api_key" parameter or MEM0_API_KEY environment variable. + async with ( + AzureCliCredential() as credential, + AzureAIAgentClient(credential=credential).as_agent( + name="FriendlyAssistant", + instructions="You are a friendly assistant.", + tools=retrieve_company_report, + context_providers=[Mem0ContextProvider(user_id=user_id)], + ) as agent, + ): + # First ask the agent to retrieve a company report with no previous context. + # The agent will not be able to invoke the tool, since it doesn't know + # the company code or the report format, so it should ask for clarification. + query = "Please retrieve my company report" + print(f"User: {query}") + result = await agent.run(query) + print(f"Agent: {result}\n") + + # Now tell the agent the company code and the report format that you want to use + # and it should be able to invoke the tool and return the report. 
+ query = "I always work with CNTS and I always want a detailed report format. Please remember and retrieve it." + print(f"User: {query}") + result = await agent.run(query) + print(f"Agent: {result}\n") + + # Mem0 processes and indexes memories asynchronously. + # Wait for memories to be indexed before querying in a new thread. + # In production, consider implementing retry logic or using Mem0's + # eventual consistency handling instead of a fixed delay. + print("Waiting for memories to be processed...") + await asyncio.sleep(12) # Empirically determined delay for Mem0 indexing + + print("\nRequest within a new session:") + # Create a new session for the agent. + # The new session has no context of the previous conversation. + session = agent.create_session() + + # Since we have the mem0 component in the session, the agent should be able to + # retrieve the company report without asking for clarification, as it will + # be able to remember the user preferences from Mem0 component. + query = "Please retrieve my company report" + print(f"User: {query}") + result = await agent.run(query, session=session) + print(f"Agent: {result}\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/sessions/mem0/mem0_oss.py b/python/samples/getting_started/sessions/mem0/mem0_oss.py new file mode 100644 index 0000000000..2178bbfe58 --- /dev/null +++ b/python/samples/getting_started/sessions/mem0/mem0_oss.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import uuid + +from agent_framework import tool +from agent_framework.azure import AzureAIAgentClient +from agent_framework.mem0 import Mem0ContextProvider +from azure.identity.aio import AzureCliCredential +from mem0 import AsyncMemory + + +# NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +@tool(approval_mode="never_require") +def retrieve_company_report(company_code: str, detailed: bool) -> str: + if company_code != "CNTS": + raise ValueError("Company code not found") + if not detailed: + return "CNTS is a company that specializes in technology." + return ( + "CNTS is a company that specializes in technology. " + "It had a revenue of $10 million in 2022. It has 100 employees." + ) + + +async def main() -> None: + """Example of memory usage with local Mem0 OSS context provider.""" + print("=== Mem0 Context Provider Example ===") + + # Each record in Mem0 should be associated with agent_id or user_id or application_id or thread_id. + # In this example, we associate Mem0 records with user_id. + user_id = str(uuid.uuid4()) + + # For Azure authentication, run `az login` command in terminal or replace AzureCliCredential with preferred + # authentication option. + # By default, local Mem0 authenticates to your OpenAI using the OPENAI_API_KEY environment variable. + # See the Mem0 documentation for other LLM providers and authentication options. + local_mem0_client = AsyncMemory() + async with ( + AzureCliCredential() as credential, + AzureAIAgentClient(credential=credential).as_agent( + name="FriendlyAssistant", + instructions="You are a friendly assistant.", + tools=retrieve_company_report, + context_providers=[Mem0ContextProvider(user_id=user_id, mem0_client=local_mem0_client)], + ) as agent, + ): + # First ask the agent to retrieve a company report with no previous context. + # The agent will not be able to invoke the tool, since it doesn't know + # the company code or the report format, so it should ask for clarification. 
+ query = "Please retrieve my company report" + print(f"User: {query}") + result = await agent.run(query) + print(f"Agent: {result}\n") + + # Now tell the agent the company code and the report format that you want to use + # and it should be able to invoke the tool and return the report. + query = "I always work with CNTS and I always want a detailed report format. Please remember and retrieve it." + print(f"User: {query}") + result = await agent.run(query) + print(f"Agent: {result}\n") + + print("\nRequest within a new session:") + + # Create a new session for the agent. + # The new session has no context of the previous conversation. + session = agent.create_session() + + # Since we have the mem0 component in the session, the agent should be able to + # retrieve the company report without asking for clarification, as it will + # be able to remember the user preferences from Mem0 component. + query = "Please retrieve my company report" + print(f"User: {query}") + result = await agent.run(query, session=session) + print(f"Agent: {result}\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/sessions/mem0/mem0_threads.py b/python/samples/getting_started/sessions/mem0/mem0_threads.py new file mode 100644 index 0000000000..dd657b4e1d --- /dev/null +++ b/python/samples/getting_started/sessions/mem0/mem0_threads.py @@ -0,0 +1,167 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import uuid + +from agent_framework import tool +from agent_framework.azure import AzureAIAgentClient +from agent_framework.mem0 import Mem0ContextProvider +from azure.identity.aio import AzureCliCredential + + +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. 
+@tool(approval_mode="never_require") +def get_user_preferences(user_id: str) -> str: + """Mock function to get user preferences.""" + preferences = { + "user123": "Prefers concise responses and technical details", + "user456": "Likes detailed explanations with examples", + } + return preferences.get(user_id, "No specific preferences found") + + +async def example_global_thread_scope() -> None: + """Example 1: Global thread_id scope (memories shared across all operations).""" + print("1. Global Thread Scope Example:") + print("-" * 40) + + global_thread_id = str(uuid.uuid4()) + user_id = "user123" + + async with ( + AzureCliCredential() as credential, + AzureAIAgentClient(credential=credential).as_agent( + name="GlobalMemoryAssistant", + instructions="You are an assistant that remembers user preferences across conversations.", + tools=get_user_preferences, + context_providers=[Mem0ContextProvider( + user_id=user_id, + thread_id=global_thread_id, + scope_to_per_operation_thread_id=False, # Share memories across all sessions + )], + ) as global_agent, + ): + # Store some preferences in the global scope + query = "Remember that I prefer technical responses with code examples when discussing programming." + print(f"User: {query}") + result = await global_agent.run(query) + print(f"Agent: {result}\n") + + # Create a new session - but memories should still be accessible due to global scope + new_session = global_agent.create_session() + query = "What do you know about my preferences?" + print(f"User (new session): {query}") + result = await global_agent.run(query, session=new_session) + print(f"Agent: {result}\n") + + +async def example_per_operation_thread_scope() -> None: + """Example 2: Per-operation thread scope (memories isolated per session). + + Note: When scope_to_per_operation_thread_id=True, the provider is bound to a single session + throughout its lifetime. Use the same session object for all operations with that provider. + """ + print("2. 
Per-Operation Thread Scope Example:") + print("-" * 40) + + user_id = "user123" + + async with ( + AzureCliCredential() as credential, + AzureAIAgentClient(credential=credential).as_agent( + name="ScopedMemoryAssistant", + instructions="You are an assistant with thread-scoped memory.", + tools=get_user_preferences, + context_providers=[Mem0ContextProvider( + user_id=user_id, + scope_to_per_operation_thread_id=True, # Isolate memories per session + )], + ) as scoped_agent, + ): + # Create a specific session for this scoped provider + dedicated_session = scoped_agent.create_session() + + # Store some information in the dedicated session + query = "Remember that for this conversation, I'm working on a Python project about data analysis." + print(f"User (dedicated session): {query}") + result = await scoped_agent.run(query, session=dedicated_session) + print(f"Agent: {result}\n") + + # Test memory retrieval in the same dedicated session + query = "What project am I working on?" + print(f"User (same dedicated session): {query}") + result = await scoped_agent.run(query, session=dedicated_session) + print(f"Agent: {result}\n") + + # Store more information in the same session + query = "Also remember that I prefer using pandas and matplotlib for this project." + print(f"User (same dedicated session): {query}") + result = await scoped_agent.run(query, session=dedicated_session) + print(f"Agent: {result}\n") + + # Test comprehensive memory retrieval + query = "What do you know about my current project and preferences?" + print(f"User (same dedicated session): {query}") + result = await scoped_agent.run(query, session=dedicated_session) + print(f"Agent: {result}\n") + + +async def example_multiple_agents() -> None: + """Example 3: Multiple agents with different thread configurations.""" + print("3. 
Multiple Agents with Different Thread Configurations:") + print("-" * 40) + + agent_id_1 = "agent_personal" + agent_id_2 = "agent_work" + + async with ( + AzureCliCredential() as credential, + AzureAIAgentClient(credential=credential).as_agent( + name="PersonalAssistant", + instructions="You are a personal assistant that helps with personal tasks.", + context_providers=[Mem0ContextProvider( + agent_id=agent_id_1, + )], + ) as personal_agent, + AzureAIAgentClient(credential=credential).as_agent( + name="WorkAssistant", + instructions="You are a work assistant that helps with professional tasks.", + context_providers=[Mem0ContextProvider( + agent_id=agent_id_2, + )], + ) as work_agent, + ): + # Store personal information + query = "Remember that I like to exercise at 6 AM and prefer outdoor activities." + print(f"User to Personal Agent: {query}") + result = await personal_agent.run(query) + print(f"Personal Agent: {result}\n") + + # Store work information + query = "Remember that I have team meetings every Tuesday at 2 PM." + print(f"User to Work Agent: {query}") + result = await work_agent.run(query) + print(f"Work Agent: {result}\n") + + # Test memory isolation + query = "What do you know about my schedule?" 
+ print(f"User to Personal Agent: {query}") + result = await personal_agent.run(query) + print(f"Personal Agent: {result}\n") + + print(f"User to Work Agent: {query}") + result = await work_agent.run(query) + print(f"Work Agent: {result}\n") + + +async def main() -> None: + """Run all Mem0 thread management examples.""" + print("=== Mem0 Thread Management Example ===\n") + + await example_global_thread_scope() + await example_per_operation_thread_scope() + await example_multiple_agents() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/sessions/redis/README.md b/python/samples/getting_started/sessions/redis/README.md new file mode 100644 index 0000000000..dec2c77485 --- /dev/null +++ b/python/samples/getting_started/sessions/redis/README.md @@ -0,0 +1,113 @@ +# Redis Context Provider Examples + +The Redis context provider enables persistent, searchable memory for your agents using Redis (RediSearch). It supports full‑text search and optional hybrid search with vector embeddings, letting agents remember and retrieve user context across sessions and threads. + +This folder contains an example demonstrating how to use the Redis context provider with the Agent Framework. + +## Examples + +| File | Description | +|------|-------------| +| [`azure_redis_conversation.py`](azure_redis_conversation.py) | Demonstrates conversation persistence with RedisHistoryProvider and Azure Redis with Azure AD (Entra ID) authentication using credential provider. | +| [`redis_basics.py`](redis_basics.py) | Shows standalone provider usage and agent integration. Demonstrates writing messages to Redis, retrieving context via full‑text or hybrid vector search, and persisting preferences across threads. Also includes a simple tool example whose outputs are remembered. 
| +| [`redis_conversation.py`](redis_conversation.py) | Simple example showing conversation persistence with RedisContextProvider using traditional connection string authentication. | +| [`redis_threads.py`](redis_threads.py) | Demonstrates thread scoping. Includes: (1) global thread scope with a fixed `thread_id` shared across operations; (2) per‑operation thread scope where `scope_to_per_operation_thread_id=True` binds memory to a single thread for the provider's lifetime; and (3) multiple agents with isolated memory via different `agent_id` values. | + + +## Prerequisites + +### Required resources + +1. A running Redis with RediSearch (Redis Stack or a managed service) +2. Python environment with Agent Framework Redis extra installed +3. Optional: OpenAI API key if using vector embeddings + +### Install the package + +```bash +pip install "agent-framework-redis" +``` + +## Running Redis + +Pick one option: + +### Option A: Docker (local Redis Stack) + +```bash +docker run --name redis -p 6379:6379 -d redis:8.0.3 +``` + +### Option B: Redis Cloud + +Create a free database and get the connection URL at `https://redis.io/cloud/`. + +### Option C: Azure Managed Redis + +See quickstart: `https://learn.microsoft.com/azure/redis/quickstart-create-managed-redis` + +## Configuration + +### Environment variables + +- `OPENAI_API_KEY` (optional): Required only if you set `vectorizer_choice="openai"` to enable hybrid search. + +### Provider configuration highlights + +The provider supports both full‑text only and hybrid vector search: + +- Set `vectorizer_choice` to `"openai"` or `"hf"` to enable embeddings and hybrid search. +- When using a vectorizer, also set `vector_field_name` (e.g., `"vector"`). +- Partition fields for scoping memory: `application_id`, `agent_id`, `user_id`, `thread_id`. +- Thread scoping: `scope_to_per_operation_thread_id=True` isolates memory per operation thread. +- Index management: `index_name`, `overwrite_redis_index`, `drop_redis_index`. 
+ +## What the example does + +`redis_basics.py` walks through three scenarios: + +1. Standalone provider usage: adds messages and retrieves context via `invoking`. +2. Agent integration: teaches the agent a preference and verifies it is remembered across turns. +3. Agent + tool: calls a sample tool (flight search) and then asks the agent to recall details remembered from the tool output. + +It uses OpenAI for both chat (via `OpenAIChatClient`) and, in some steps, optional embeddings for hybrid search. + +## How to run + +1) Start Redis (see options above). For local default, ensure it's reachable at `redis://localhost:6379`. + +2) Set your OpenAI key if using embeddings and for the chat client used in the sample: + +```bash +export OPENAI_API_KEY="" +``` + +3) Run the example: + +```bash +python redis_basics.py +``` + +You should see the agent responses and, when using embeddings, context retrieved from Redis. The example includes commented debug helpers you can print, such as index info or all stored docs. + +## Key concepts + +### Memory scoping + +- Global scope: set `application_id`, `agent_id`, `user_id`, or `thread_id` on the provider to filter memory. +- Per‑operation thread scope: set `scope_to_per_operation_thread_id=True` to isolate memory to the current thread created by the framework. + +### Hybrid vector search (optional) + +- Enable by setting `vectorizer_choice` to `"openai"` (requires `OPENAI_API_KEY`) or `"hf"` (offline model). +- Provide `vector_field_name` (e.g., `"vector"`); other vector settings have sensible defaults. + +### Index lifecycle controls + +- `overwrite_redis_index` and `drop_redis_index` help recreate indexes during iteration. + +## Troubleshooting + +- Ensure at least one of `application_id`, `agent_id`, `user_id`, or `thread_id` is set; the provider requires a scope. +- If using embeddings, verify `OPENAI_API_KEY` is set and reachable. 
+- Make sure Redis exposes RediSearch (Redis Stack image or managed service with search enabled). diff --git a/python/samples/getting_started/sessions/redis/azure_redis_conversation.py b/python/samples/getting_started/sessions/redis/azure_redis_conversation.py new file mode 100644 index 0000000000..ce569be8cb --- /dev/null +++ b/python/samples/getting_started/sessions/redis/azure_redis_conversation.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Azure Managed Redis History Provider with Azure AD Authentication + +This example demonstrates how to use Azure Managed Redis with Azure AD authentication +to persist conversational details using RedisHistoryProvider. + +Requirements: + - Azure Managed Redis instance with Azure AD authentication enabled + - Azure credentials configured (az login or managed identity) + - agent-framework-redis: pip install agent-framework-redis + - azure-identity: pip install azure-identity + +Environment Variables: + - AZURE_REDIS_HOST: Your Azure Managed Redis host (e.g., myredis.redis.cache.windows.net) + - OPENAI_API_KEY: Your OpenAI API key + - OPENAI_CHAT_MODEL_ID: OpenAI model (e.g., gpt-4o-mini) + - AZURE_USER_OBJECT_ID: Your Azure AD User Object ID for authentication +""" + +import asyncio +import os + +from agent_framework.openai import OpenAIChatClient +from agent_framework.redis import RedisHistoryProvider +from azure.identity.aio import AzureCliCredential +from redis.credentials import CredentialProvider + + +class AzureCredentialProvider(CredentialProvider): + """Credential provider for Azure AD authentication with Redis Enterprise.""" + + def __init__(self, azure_credential: AzureCliCredential, user_object_id: str): + self.azure_credential = azure_credential + self.user_object_id = user_object_id + + async def get_credentials_async(self) -> tuple[str] | tuple[str, str]: + """Get Azure AD token for Redis authentication. + + Returns (username, token) where username is the Azure user's Object ID. 
+ """ + token = await self.azure_credential.get_token("https://redis.azure.com/.default") + return (self.user_object_id, token.token) + + +async def main() -> None: + redis_host = os.environ.get("AZURE_REDIS_HOST") + if not redis_host: + print("ERROR: Set AZURE_REDIS_HOST environment variable") + return + + # For Azure Redis with Entra ID, username must be your Object ID + user_object_id = os.environ.get("AZURE_USER_OBJECT_ID") + if not user_object_id: + print("ERROR: Set AZURE_USER_OBJECT_ID environment variable") + print("Get your Object ID from the Azure Portal") + return + + # Create Azure CLI credential provider (uses 'az login' credentials) + azure_credential = AzureCliCredential() + credential_provider = AzureCredentialProvider(azure_credential, user_object_id) + + session_id = "azure_test_session" + + # Create Azure Redis history provider + history_provider = RedisHistoryProvider( + credential_provider=credential_provider, + host=redis_host, + port=10000, + ssl=True, + thread_id=session_id, + key_prefix="chat_messages", + max_messages=100, + ) + + # Create chat client + client = OpenAIChatClient() + + # Create agent with Azure Redis history provider + agent = client.as_agent( + name="AzureRedisAssistant", + instructions="You are a helpful assistant.", + context_providers=[history_provider], + ) + + # Conversation + query = "Remember that I enjoy gumbo" + result = await agent.run(query) + print("User: ", query) + print("Agent: ", result) + + # Ask the agent to recall the stored preference; it should retrieve from memory + query = "What do I enjoy?" + result = await agent.run(query) + print("User: ", query) + print("Agent: ", result) + + query = "What did I say to you just now?" 
+ result = await agent.run(query) + print("User: ", query) + print("Agent: ", result) + + query = "Remember that I have a meeting at 3pm tomorrow" + result = await agent.run(query) + print("User: ", query) + print("Agent: ", result) + + query = "Tulips are red" + result = await agent.run(query) + print("User: ", query) + print("Agent: ", result) + + query = "What was the first thing I said to you this conversation?" + result = await agent.run(query) + print("User: ", query) + print("Agent: ", result) + + # Cleanup + await azure_credential.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/sessions/redis/redis_basics.py b/python/samples/getting_started/sessions/redis/redis_basics.py new file mode 100644 index 0000000000..81238eb171 --- /dev/null +++ b/python/samples/getting_started/sessions/redis/redis_basics.py @@ -0,0 +1,256 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Redis Context Provider: Basic usage and agent integration + +This example demonstrates how to use the Redis context provider to persist and +retrieve conversational memory for agents. It covers three progressively more +realistic scenarios: + +1) Standalone provider usage ("basic cache") + - Write messages to Redis and retrieve relevant context using full-text or + hybrid vector search. + +2) Agent + provider + - Connect the provider to an agent so the agent can store user preferences + and recall them across turns. + +3) Agent + provider + tool memory + - Expose a simple tool to the agent, then verify that details from the tool + outputs are captured and retrievable as part of the agent's memory. 
+ +Requirements: + - A Redis instance with RediSearch enabled (e.g., Redis Stack) + - agent-framework with the Redis extra installed: pip install "agent-framework-redis" + - Optionally an OpenAI API key if enabling embeddings for hybrid search + +Run: + python redis_basics.py +""" + +import asyncio +import os + +from agent_framework import Message, tool +from agent_framework.openai import OpenAIChatClient +from agent_framework.redis import RedisContextProvider +from redisvl.extensions.cache.embeddings import EmbeddingsCache +from redisvl.utils.vectorize import OpenAITextVectorizer + + +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +@tool(approval_mode="never_require") +def search_flights(origin_airport_code: str, destination_airport_code: str, detailed: bool = False) -> str: + """Simulated flight-search tool to demonstrate tool memory. + + The agent can call this function, and the returned details can be stored + by the Redis context provider. We later ask the agent to recall facts from + these tool results to verify memory is working as expected. 
+ """ + # Minimal static catalog used to simulate a tool's structured output + flights = { + ("JFK", "LAX"): { + "airline": "SkyJet", + "duration": "6h 15m", + "price": 325, + "cabin": "Economy", + "baggage": "1 checked bag", + }, + ("SFO", "SEA"): { + "airline": "Pacific Air", + "duration": "2h 5m", + "price": 129, + "cabin": "Economy", + "baggage": "Carry-on only", + }, + ("LHR", "DXB"): { + "airline": "EuroWings", + "duration": "6h 50m", + "price": 499, + "cabin": "Business", + "baggage": "2 bags included", + }, + } + + route = (origin_airport_code.upper(), destination_airport_code.upper()) + if route not in flights: + return f"No flights found between {origin_airport_code} and {destination_airport_code}" + + flight = flights[route] + if not detailed: + return f"Flights available from {origin_airport_code} to {destination_airport_code}." + + return ( + f"{flight['airline']} operates flights from {origin_airport_code} to {destination_airport_code}. " + f"Duration: {flight['duration']}. " + f"Price: ${flight['price']}. " + f"Cabin: {flight['cabin']}. " + f"Baggage policy: {flight['baggage']}." + ) + + +async def main() -> None: + """Walk through provider-only, agent integration, and tool-memory scenarios. + + Helpful debugging (uncomment when iterating): + - print(await provider.redis_index.info()) + - print(await provider.search_all()) + """ + + print("1. Standalone provider usage:") + print("-" * 40) + # Create a provider with partition scope and OpenAI embeddings + + # Please set the OPENAI_API_KEY and OPENAI_CHAT_MODEL_ID environment variables to use the OpenAI vectorizer + # Recommend default for OPENAI_CHAT_MODEL_ID is gpt-4o-mini + + # We attach an embedding vectorizer so the provider can perform hybrid (text + vector) + # retrieval. If you prefer text-only retrieval, instantiate RedisContextProvider without the + # 'vectorizer' and vector_* parameters. 
+ vectorizer = OpenAITextVectorizer( + model="text-embedding-ada-002", + api_config={"api_key": os.getenv("OPENAI_API_KEY")}, + cache=EmbeddingsCache(name="openai_embeddings_cache", redis_url="redis://localhost:6379"), + ) + # The provider manages persistence and retrieval. application_id/agent_id/user_id + # scope data for multi-tenant separation; thread_id (set later) narrows to a + # specific conversation. + provider = RedisContextProvider( + redis_url="redis://localhost:6379", + index_name="redis_basics", + application_id="matrix_of_kermits", + agent_id="agent_kermit", + user_id="kermit", + redis_vectorizer=vectorizer, + vector_field_name="vector", + vector_algorithm="hnsw", + vector_distance_metric="cosine", + ) + + # Build sample chat messages to persist to Redis + messages = [ + Message("user", ["runA CONVO: User Message"]), + Message("assistant", ["runA CONVO: Assistant Message"]), + Message("system", ["runA CONVO: System Message"]), + ] + + # Use the provider's before_run/after_run API to store and retrieve messages. + # In practice, the agent handles this automatically; this shows the low-level API. + from agent_framework import AgentSession, SessionContext + + session = AgentSession(session_id="runA") + context = SessionContext() + context.extend_messages("input", messages) + state = session.state + + # Store messages via after_run + await provider.after_run(agent=None, session=session, context=context, state=state) + + # Retrieve relevant memories via before_run + query_context = SessionContext() + query_context.extend_messages("input", [Message("system", ["B: Assistant Message"])]) + await provider.before_run(agent=None, session=session, context=query_context, state=state) + + # Inspect retrieved memories that would be injected into instructions + # (Debug-only output so you can verify retrieval works as expected.) 
+ print("Before Run Result:") + print(query_context) + + # Drop / delete the provider index in Redis + await provider.redis_index.delete() + + # --- Agent + provider: teach and recall a preference --- + + print("\n2. Agent + provider: teach and recall a preference") + print("-" * 40) + # Fresh provider for the agent demo (recreates index) + vectorizer = OpenAITextVectorizer( + model="text-embedding-ada-002", + api_config={"api_key": os.getenv("OPENAI_API_KEY")}, + cache=EmbeddingsCache(name="openai_embeddings_cache", redis_url="redis://localhost:6379"), + ) + # Recreate a clean index so the next scenario starts fresh + provider = RedisContextProvider( + redis_url="redis://localhost:6379", + index_name="redis_basics_2", + prefix="context_2", + application_id="matrix_of_kermits", + agent_id="agent_kermit", + user_id="kermit", + redis_vectorizer=vectorizer, + vector_field_name="vector", + vector_algorithm="hnsw", + vector_distance_metric="cosine", + ) + + # Create chat client for the agent + client = OpenAIChatClient(model_id=os.getenv("OPENAI_CHAT_MODEL_ID"), api_key=os.getenv("OPENAI_API_KEY")) + # Create agent wired to the Redis context provider. The provider automatically + # persists conversational details and surfaces relevant context on each turn. + agent = client.as_agent( + name="MemoryEnhancedAssistant", + instructions=( + "You are a helpful assistant. Personalize replies using provided context. " + "Before answering, always check for stored context" + ), + tools=[], + context_providers=[provider], + ) + + # Teach a user preference; the agent writes this to the provider's memory + query = "Remember that I enjoy glugenflorgle" + result = await agent.run(query) + print("User: ", query) + print("Agent: ", result) + + # Ask the agent to recall the stored preference; it should retrieve from memory + query = "What do I enjoy?" 
+ result = await agent.run(query) + print("User: ", query) + print("Agent: ", result) + + # Drop / delete the provider index in Redis + await provider.redis_index.delete() + + # --- Agent + provider + tool: store and recall tool-derived context --- + + print("\n3. Agent + provider + tool: store and recall tool-derived context") + print("-" * 40) + # Text-only provider (full-text search only). Omits vectorizer and related params. + provider = RedisContextProvider( + redis_url="redis://localhost:6379", + index_name="redis_basics_3", + prefix="context_3", + application_id="matrix_of_kermits", + agent_id="agent_kermit", + user_id="kermit", + ) + + # Create agent exposing the flight search tool. Tool outputs are captured by the + # provider and become retrievable context for later turns. + client = OpenAIChatClient(model_id=os.getenv("OPENAI_CHAT_MODEL_ID"), api_key=os.getenv("OPENAI_API_KEY")) + agent = client.as_agent( + name="MemoryEnhancedAssistant", + instructions=( + "You are a helpful assistant. Personalize replies using provided context. " + "Before answering, always check for stored context" + ), + tools=search_flights, + context_providers=[provider], + ) + # Invoke the tool; outputs become part of memory/context + query = "Are there any flights from new york city (jfk) to la? Give me details" + result = await agent.run(query) + print("User: ", query) + print("Agent: ", result) + # Verify the agent can recall tool-derived context + query = "Which flight did I ask about?" 
+    result = await agent.run(query)
+    print("User: ", query)
+    print("Agent: ", result)
+
+    # Drop / delete the provider index in Redis
+    await provider.redis_index.delete()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/python/samples/getting_started/sessions/redis/redis_conversation.py b/python/samples/getting_started/sessions/redis/redis_conversation.py
new file mode 100644
index 0000000000..2d345d9930
--- /dev/null
+++ b/python/samples/getting_started/sessions/redis/redis_conversation.py
@@ -0,0 +1,105 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+"""Redis Context Provider: Conversation persistence with an agent
+
+This example demonstrates how to use the Redis context provider to persist
+conversational details. Pass it to the agent via the context_providers
+argument of as_agent.
+
+Requirements:
+    - A Redis instance with RediSearch enabled (e.g., Redis Stack)
+    - agent-framework with the Redis extra installed: pip install "agent-framework-redis"
+    - Optionally an OpenAI API key if enabling embeddings for hybrid search
+
+Run:
+    python redis_conversation.py
+"""
+
+import asyncio
+import os
+
+from agent_framework.openai import OpenAIChatClient
+from agent_framework.redis import RedisContextProvider
+from redisvl.extensions.cache.embeddings import EmbeddingsCache
+from redisvl.utils.vectorize import OpenAITextVectorizer
+
+
+async def main() -> None:
+    """Walk through provider and chat message store usage.
+ + Helpful debugging (uncomment when iterating): + - print(await provider.redis_index.info()) + - print(await provider.search_all()) + """ + vectorizer = OpenAITextVectorizer( + model="text-embedding-ada-002", + api_config={"api_key": os.getenv("OPENAI_API_KEY")}, + cache=EmbeddingsCache(name="openai_embeddings_cache", redis_url="redis://localhost:6379"), + ) + + session_id = "test_session" + + provider = RedisContextProvider( + redis_url="redis://localhost:6379", + index_name="redis_conversation", + prefix="redis_conversation", + application_id="matrix_of_kermits", + agent_id="agent_kermit", + user_id="kermit", + redis_vectorizer=vectorizer, + vector_field_name="vector", + vector_algorithm="hnsw", + vector_distance_metric="cosine", + thread_id=session_id, + ) + + # Create chat client for the agent + client = OpenAIChatClient(model_id=os.getenv("OPENAI_CHAT_MODEL_ID"), api_key=os.getenv("OPENAI_API_KEY")) + # Create agent wired to the Redis context provider. The provider automatically + # persists conversational details and surfaces relevant context on each turn. + agent = client.as_agent( + name="MemoryEnhancedAssistant", + instructions=( + "You are a helpful assistant. Personalize replies using provided context. " + "Before answering, always check for stored context" + ), + tools=[], + context_providers=[provider], + ) + + # Teach a user preference; the agent writes this to the provider's memory + query = "Remember that I enjoy gumbo" + result = await agent.run(query) + print("User: ", query) + print("Agent: ", result) + + # Ask the agent to recall the stored preference; it should retrieve from memory + query = "What do I enjoy?" + result = await agent.run(query) + print("User: ", query) + print("Agent: ", result) + + query = "What did I say to you just now?" 
+    result = await agent.run(query)
+    print("User: ", query)
+    print("Agent: ", result)
+
+    query = "Remember that I have a meeting at 3pm tomorrow"
+    result = await agent.run(query)
+    print("User: ", query)
+    print("Agent: ", result)
+
+    query = "Tulips are red"
+    result = await agent.run(query)
+    print("User: ", query)
+    print("Agent: ", result)
+
+    query = "What was the first thing I said to you this conversation?"
+    result = await agent.run(query)
+    print("User: ", query)
+    print("Agent: ", result)
+    # Drop / delete the provider index in Redis
+    await provider.redis_index.delete()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/python/samples/getting_started/sessions/redis/redis_threads.py b/python/samples/getting_started/sessions/redis/redis_threads.py
new file mode 100644
index 0000000000..34179048d9
--- /dev/null
+++ b/python/samples/getting_started/sessions/redis/redis_threads.py
@@ -0,0 +1,249 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+"""Redis Context Provider: Thread scoping examples
+
+This sample demonstrates how conversational memory can be scoped when using the
+Redis context provider. It covers three scenarios:
+
+1) Global thread scope
+   - Provide a fixed thread_id to share memories across operations/threads.
+
+2) Per-operation thread scope
+   - Enable scope_to_per_operation_thread_id to bind the provider to a single
+     thread for the lifetime of that provider instance. Use the same thread
+     object for reads/writes with that provider.
+
+3) Multiple agents with isolated memory
+   - Use different agent_id values to keep memories separated for different
+     agent personas, even when the user_id is the same.
+ +Requirements: + - A Redis instance with RediSearch enabled (e.g., Redis Stack) + - agent-framework with the Redis extra installed: pip install "agent-framework-redis" + - Optionally an OpenAI API key for the chat client in this demo + +Run: + python redis_threads.py +""" + +import asyncio +import os +import uuid + +from agent_framework.openai import OpenAIChatClient +from agent_framework.redis import RedisContextProvider +from redisvl.extensions.cache.embeddings import EmbeddingsCache +from redisvl.utils.vectorize import OpenAITextVectorizer + +# Please set the OPENAI_API_KEY and OPENAI_CHAT_MODEL_ID environment variables to use the OpenAI vectorizer +# Recommend default for OPENAI_CHAT_MODEL_ID is gpt-4o-mini + + +async def example_global_thread_scope() -> None: + """Example 1: Global thread_id scope (memories shared across all operations).""" + print("1. Global Thread Scope Example:") + print("-" * 40) + + global_thread_id = str(uuid.uuid4()) + + client = OpenAIChatClient( + model_id=os.getenv("OPENAI_CHAT_MODEL_ID", "gpt-4o-mini"), + api_key=os.getenv("OPENAI_API_KEY"), + ) + + provider = RedisContextProvider( + redis_url="redis://localhost:6379", + index_name="redis_threads_global", + application_id="threads_demo_app", + agent_id="threads_demo_agent", + user_id="threads_demo_user", + thread_id=global_thread_id, + scope_to_per_operation_thread_id=False, # Share memories across all sessions + ) + + agent = client.as_agent( + name="GlobalMemoryAssistant", + instructions=( + "You are a helpful assistant. Personalize replies using provided context. " + "Before answering, always check for stored context containing information" + ), + tools=[], + context_providers=[provider], + ) + + # Store a preference in the global scope + query = "Remember that I prefer technical responses with code examples when discussing programming." 
+ print(f"User: {query}") + result = await agent.run(query) + print(f"Agent: {result}\n") + + # Create a new session - memories should still be accessible due to global scope + new_session = agent.create_session() + query = "What technical responses do I prefer?" + print(f"User (new session): {query}") + result = await agent.run(query, session=new_session) + print(f"Agent: {result}\n") + + # Clean up the Redis index + await provider.redis_index.delete() + + +async def example_per_operation_thread_scope() -> None: + """Example 2: Per-operation thread scope (memories isolated per session). + + Note: When scope_to_per_operation_thread_id=True, the provider is bound to a single session + throughout its lifetime. Use the same session object for all operations with that provider. + """ + print("2. Per-Operation Thread Scope Example:") + print("-" * 40) + + client = OpenAIChatClient( + model_id=os.getenv("OPENAI_CHAT_MODEL_ID", "gpt-4o-mini"), + api_key=os.getenv("OPENAI_API_KEY"), + ) + + vectorizer = OpenAITextVectorizer( + model="text-embedding-ada-002", + api_config={"api_key": os.getenv("OPENAI_API_KEY")}, + cache=EmbeddingsCache(name="openai_embeddings_cache", redis_url="redis://localhost:6379"), + ) + + provider = RedisContextProvider( + redis_url="redis://localhost:6379", + index_name="redis_threads_dynamic", + # overwrite_redis_index=True, + # drop_redis_index=True, + application_id="threads_demo_app", + agent_id="threads_demo_agent", + user_id="threads_demo_user", + scope_to_per_operation_thread_id=True, # Isolate memories per session + redis_vectorizer=vectorizer, + vector_field_name="vector", + vector_algorithm="hnsw", + vector_distance_metric="cosine", + ) + + agent = client.as_agent( + name="ScopedMemoryAssistant", + instructions="You are an assistant with thread-scoped memory.", + context_providers=[provider], + ) + + # Create a specific session for this scoped provider + dedicated_session = agent.create_session() + + # Store some information in the 
dedicated session + query = "Remember that for this conversation, I'm working on a Python project about data analysis." + print(f"User (dedicated session): {query}") + result = await agent.run(query, session=dedicated_session) + print(f"Agent: {result}\n") + + # Test memory retrieval in the same dedicated session + query = "What project am I working on?" + print(f"User (same dedicated session): {query}") + result = await agent.run(query, session=dedicated_session) + print(f"Agent: {result}\n") + + # Store more information in the same session + query = "Also remember that I prefer using pandas and matplotlib for this project." + print(f"User (same dedicated session): {query}") + result = await agent.run(query, session=dedicated_session) + print(f"Agent: {result}\n") + + # Test comprehensive memory retrieval + query = "What do you know about my current project and preferences?" + print(f"User (same dedicated session): {query}") + result = await agent.run(query, session=dedicated_session) + print(f"Agent: {result}\n") + + # Clean up the Redis index + await provider.redis_index.delete() + + +async def example_multiple_agents() -> None: + """Example 3: Multiple agents with different thread configurations (isolated via agent_id) but within 1 index.""" + print("3. 
Multiple Agents with Different Thread Configurations:") + print("-" * 40) + + client = OpenAIChatClient( + model_id=os.getenv("OPENAI_CHAT_MODEL_ID", "gpt-4o-mini"), + api_key=os.getenv("OPENAI_API_KEY"), + ) + + vectorizer = OpenAITextVectorizer( + model="text-embedding-ada-002", + api_config={"api_key": os.getenv("OPENAI_API_KEY")}, + cache=EmbeddingsCache(name="openai_embeddings_cache", redis_url="redis://localhost:6379"), + ) + + personal_provider = RedisContextProvider( + redis_url="redis://localhost:6379", + index_name="redis_threads_agents", + application_id="threads_demo_app", + agent_id="agent_personal", + user_id="threads_demo_user", + redis_vectorizer=vectorizer, + vector_field_name="vector", + vector_algorithm="hnsw", + vector_distance_metric="cosine", + ) + + personal_agent = client.as_agent( + name="PersonalAssistant", + instructions="You are a personal assistant that helps with personal tasks.", + context_providers=[personal_provider], + ) + + work_provider = RedisContextProvider( + redis_url="redis://localhost:6379", + index_name="redis_threads_agents", + application_id="threads_demo_app", + agent_id="agent_work", + user_id="threads_demo_user", + redis_vectorizer=vectorizer, + vector_field_name="vector", + vector_algorithm="hnsw", + vector_distance_metric="cosine", + ) + + work_agent = client.as_agent( + name="WorkAssistant", + instructions="You are a work assistant that helps with professional tasks.", + context_providers=[work_provider], + ) + + # Store personal information + query = "Remember that I like to exercise at 6 AM and prefer outdoor activities." + print(f"User to Personal Agent: {query}") + result = await personal_agent.run(query) + print(f"Personal Agent: {result}\n") + + # Store work information + query = "Remember that I have team meetings every Tuesday at 2 PM." 
+ print(f"User to Work Agent: {query}") + result = await work_agent.run(query) + print(f"Work Agent: {result}\n") + + # Test memory isolation + query = "What do you know about my schedule?" + print(f"User to Personal Agent: {query}") + result = await personal_agent.run(query) + print(f"Personal Agent: {result}\n") + + print(f"User to Work Agent: {query}") + result = await work_agent.run(query) + print(f"Work Agent: {result}\n") + + # Clean up the Redis index (shared) + await work_provider.redis_index.delete() + + +async def main() -> None: + print("=== Redis Thread Scoping Examples ===\n") + await example_global_thread_scope() + await example_per_operation_thread_scope() + await example_multiple_agents() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/sessions/redis_chat_message_store_thread.py b/python/samples/getting_started/sessions/redis_chat_message_store_thread.py new file mode 100644 index 0000000000..f54edd8170 --- /dev/null +++ b/python/samples/getting_started/sessions/redis_chat_message_store_thread.py @@ -0,0 +1,257 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import os +from uuid import uuid4 + +from agent_framework import AgentSession +from agent_framework.openai import OpenAIChatClient +from agent_framework.redis import RedisHistoryProvider + +""" +Redis History Provider Session Example + +This sample demonstrates how to use Redis as a history provider for session +management, enabling persistent conversation history storage across sessions +with Redis as the backend data store. 
+""" + + +async def example_manual_memory_store() -> None: + """Basic example of using Redis history provider.""" + print("=== Basic Redis History Provider Example ===") + + # Create Redis history provider + redis_provider = RedisHistoryProvider( + source_id="redis_basic_chat", + redis_url="redis://localhost:6379", + ) + + # Create agent with Redis history provider + agent = OpenAIChatClient().as_agent( + name="RedisBot", + instructions="You are a helpful assistant that remembers our conversation using Redis.", + context_providers=[redis_provider], + ) + + # Create session + session = agent.create_session() + + # Have a conversation + print("\n--- Starting conversation ---") + query1 = "Hello! My name is Alice and I love pizza." + print(f"User: {query1}") + response1 = await agent.run(query1, session=session) + print(f"Agent: {response1.text}") + + query2 = "What do you remember about me?" + print(f"User: {query2}") + response2 = await agent.run(query2, session=session) + print(f"Agent: {response2.text}") + + print("Done\n") + + +async def example_user_session_management() -> None: + """Example of managing user sessions with Redis.""" + print("=== User Session Management Example ===") + + user_id = "alice_123" + session_id = f"session_{uuid4()}" + + # Create Redis history provider for specific user session + redis_provider = RedisHistoryProvider( + source_id=f"redis_{user_id}", + redis_url="redis://localhost:6379", + max_messages=10, # Keep only last 10 messages + ) + + # Create agent with history provider + agent = OpenAIChatClient().as_agent( + name="SessionBot", + instructions="You are a helpful assistant. 
Keep track of user preferences.", + context_providers=[redis_provider], + ) + + # Start conversation + session = agent.create_session(session_id=session_id) + + print(f"Started session for user {user_id}") + + # Simulate conversation + queries = [ + "Hi, I'm Alice and I prefer vegetarian food.", + "What restaurants would you recommend?", + "I also love Italian cuisine.", + "Can you remember my food preferences?", + ] + + for i, query in enumerate(queries, 1): + print(f"\n--- Message {i} ---") + print(f"User: {query}") + response = await agent.run(query, session=session) + print(f"Agent: {response.text}") + + print("Done\n") + + +async def example_conversation_persistence() -> None: + """Example of conversation persistence across application restarts.""" + print("=== Conversation Persistence Example ===") + + # Phase 1: Start conversation + print("--- Phase 1: Starting conversation ---") + redis_provider = RedisHistoryProvider( + source_id="redis_persistent_chat", + redis_url="redis://localhost:6379", + ) + + agent = OpenAIChatClient().as_agent( + name="PersistentBot", + instructions="You are a helpful assistant. Remember our conversation history.", + context_providers=[redis_provider], + ) + + session = agent.create_session() + + # Start conversation + query1 = "Hello! I'm working on a Python project about machine learning." + print(f"User: {query1}") + response1 = await agent.run(query1, session=session) + print(f"Agent: {response1.text}") + + query2 = "I'm specifically interested in neural networks." + print(f"User: {query2}") + response2 = await agent.run(query2, session=session) + print(f"Agent: {response2.text}") + + # Serialize session state + serialized = session.to_dict() + + # Phase 2: Resume conversation (simulating app restart) + print("\n--- Phase 2: Resuming conversation (after 'restart') ---") + restored_session = AgentSession.from_dict(serialized) + + # Continue conversation - agent should remember context + query3 = "What was I working on before?" 
+ print(f"User: {query3}") + response3 = await agent.run(query3, session=restored_session) + print(f"Agent: {response3.text}") + + query4 = "Can you suggest some Python libraries for neural networks?" + print(f"User: {query4}") + response4 = await agent.run(query4, session=restored_session) + print(f"Agent: {response4.text}") + + print("Done\n") + + +async def example_session_serialization() -> None: + """Example of session state serialization and deserialization.""" + print("=== Session Serialization Example ===") + + redis_provider = RedisHistoryProvider( + source_id="redis_serialization_chat", + redis_url="redis://localhost:6379", + ) + + agent = OpenAIChatClient().as_agent( + name="SerializationBot", + instructions="You are a helpful assistant.", + context_providers=[redis_provider], + ) + + session = agent.create_session() + + # Have initial conversation + print("--- Initial conversation ---") + query1 = "Hello! I'm testing serialization." + print(f"User: {query1}") + response1 = await agent.run(query1, session=session) + print(f"Agent: {response1.text}") + + # Serialize session state + serialized = session.to_dict() + print(f"\nSerialized session state: {serialized}") + + # Deserialize session state (simulating loading from database/file) + print("\n--- Deserializing session state ---") + restored_session = AgentSession.from_dict(serialized) + + # Continue conversation with restored session + query2 = "Do you remember what I said about testing?" 
+ print(f"User: {query2}") + response2 = await agent.run(query2, session=restored_session) + print(f"Agent: {response2.text}") + + print("Done\n") + + +async def example_message_limits() -> None: + """Example of automatic message trimming with limits.""" + print("=== Message Limits Example ===") + + # Create provider with small message limit + redis_provider = RedisHistoryProvider( + source_id="redis_limited_chat", + redis_url="redis://localhost:6379", + max_messages=3, # Keep only 3 most recent messages + ) + + agent = OpenAIChatClient().as_agent( + name="LimitBot", + instructions="You are a helpful assistant with limited memory.", + context_providers=[redis_provider], + ) + + session = agent.create_session() + + # Send multiple messages to test trimming + messages = [ + "Message 1: Hello!", + "Message 2: How are you?", + "Message 3: What's the weather?", + "Message 4: Tell me a joke.", + "Message 5: This should trigger trimming.", + ] + + for i, query in enumerate(messages, 1): + print(f"\n--- Sending message {i} ---") + print(f"User: {query}") + response = await agent.run(query, session=session) + print(f"Agent: {response.text}") + + print("Done\n") + + +async def main() -> None: + """Run all Redis history provider examples.""" + print("Redis History Provider Examples") + print("=" * 50) + print("Prerequisites:") + print("- Redis server running on localhost:6379") + print("- OPENAI_API_KEY environment variable set") + print("=" * 50) + + # Check prerequisites + if not os.getenv("OPENAI_API_KEY"): + print("ERROR: OPENAI_API_KEY environment variable not set") + return + + try: + # Run all examples + await example_manual_memory_store() + await example_user_session_management() + await example_conversation_persistence() + await example_session_serialization() + await example_message_limits() + + print("All examples completed successfully!") + + except Exception as e: + print(f"Error running examples: {e}") + raise + + +if __name__ == "__main__": + 
asyncio.run(main()) diff --git a/python/samples/getting_started/sessions/simple_context_provider.py b/python/samples/getting_started/sessions/simple_context_provider.py new file mode 100644 index 0000000000..fd2a7ce747 --- /dev/null +++ b/python/samples/getting_started/sessions/simple_context_provider.py @@ -0,0 +1,129 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +from typing import Any + +from agent_framework import Agent, AgentSession, BaseContextProvider, SessionContext, SupportsChatGetResponse +from agent_framework.azure import AzureAIClient +from azure.identity.aio import AzureCliCredential +from pydantic import BaseModel + + +class UserInfo(BaseModel): + name: str | None = None + age: int | None = None + + +class UserInfoMemory(BaseContextProvider): + def __init__(self, client: SupportsChatGetResponse, user_info: UserInfo | None = None, **kwargs: Any): + """Create the memory. + + If you pass in kwargs, they will be attempted to be used to create a UserInfo object. + """ + super().__init__("user-info-memory") + self._chat_client = client + if user_info: + self.user_info = user_info + elif kwargs: + self.user_info = UserInfo.model_validate(kwargs) + else: + self.user_info = UserInfo() + + async def after_run( + self, + *, + agent: Any, + session: AgentSession | None, + context: SessionContext, + state: dict[str, Any], + ) -> None: + """Extract user information from messages after each agent call.""" + request_messages = context.get_messages() + # Check if we need to extract user info from user messages + user_messages = [msg for msg in request_messages if hasattr(msg, "role") and msg.role == "user"] # type: ignore + + if (self.user_info.name is None or self.user_info.age is None) and user_messages: + try: + # Use the chat client to extract structured information + result = await self._chat_client.get_response( + messages=request_messages, # type: ignore + instructions="Extract the user's name and age from the message if present. 
" + "If not present return nulls.", + options={"response_format": UserInfo}, + ) + + # Update user info with extracted data + try: + extracted = result.value + if self.user_info.name is None and extracted.name: + self.user_info.name = extracted.name + if self.user_info.age is None and extracted.age: + self.user_info.age = extracted.age + except Exception: + pass # Failed to extract, continue without updating + + except Exception: + pass # Failed to extract, continue without updating + + async def before_run( + self, + *, + agent: Any, + session: AgentSession | None, + context: SessionContext, + state: dict[str, Any], + ) -> None: + """Provide user information context before each agent call.""" + instructions: list[str] = [] + + if self.user_info.name is None: + instructions.append( + "Ask the user for their name and politely decline to answer any questions until they provide it." + ) + else: + instructions.append(f"The user's name is {self.user_info.name}.") + + if self.user_info.age is None: + instructions.append( + "Ask the user for their age and politely decline to answer any questions until they provide it." + ) + else: + instructions.append(f"The user's age is {self.user_info.age}.") + + # Add context with additional instructions + context.extend_instructions(self.source_id, " ".join(instructions)) + + def serialize(self) -> str: + """Serialize the user info for session persistence.""" + return self.user_info.model_dump_json() + + +async def main(): + async with AzureCliCredential() as credential: + client = AzureAIClient(credential=credential) + + # Create the memory provider + memory_provider = UserInfoMemory(client) + + # Create the agent with memory + async with Agent( + client=client, + instructions="You are a friendly assistant. 
Always address the user by their name.", + context_providers=[memory_provider], + ) as agent: + # Create a new session for the conversation + session = agent.create_session() + + print(await agent.run("Hello, what is the square root of 9?", session=session)) + print(await agent.run("My name is Ruaidhrí", session=session)) + print(await agent.run("I am 20 years old", session=session)) + + # Access the memory component and inspect the memories + if memory_provider: + print() + print(f"MEMORY - User Name: {memory_provider.user_info.name}") + print(f"MEMORY - User Age: {memory_provider.user_info.age}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/sessions/suspend_resume_thread.py b/python/samples/getting_started/sessions/suspend_resume_thread.py new file mode 100644 index 0000000000..dcbb00d06a --- /dev/null +++ b/python/samples/getting_started/sessions/suspend_resume_thread.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio + +from agent_framework import AgentSession +from agent_framework.azure import AzureAIAgentClient +from agent_framework.openai import OpenAIChatClient +from azure.identity.aio import AzureCliCredential + +""" +Session Suspend and Resume Example + +This sample demonstrates how to suspend and resume conversation sessions, comparing +service-managed sessions (Azure AI) with in-memory sessions (OpenAI) for persistent +conversation state across sessions. +""" + + +async def suspend_resume_service_managed_session() -> None: + """Demonstrates how to suspend and resume a service-managed session.""" + print("=== Suspend-Resume Service-Managed Session ===") + + # AzureAIAgentClient supports service-managed sessions. + async with ( + AzureCliCredential() as credential, + AzureAIAgentClient(credential=credential).as_agent( + name="MemoryBot", instructions="You are a helpful assistant that remembers our conversation." 
+ ) as agent, + ): + # Start a new session for the agent conversation. + session = agent.create_session() + + # Respond to user input. + query = "Hello! My name is Alice and I love pizza." + print(f"User: {query}") + print(f"Agent: {await agent.run(query, session=session)}\n") + + # Serialize the session state, so it can be stored for later use. + serialized_session = session.to_dict() + + # The session can now be saved to a database, file, or any other storage mechanism and loaded again later. + print(f"Serialized session: {serialized_session}\n") + + # Deserialize the session state after loading from storage. + resumed_session = AgentSession.from_dict(serialized_session) + + # Respond to user input. + query = "What do you remember about me?" + print(f"User: {query}") + print(f"Agent: {await agent.run(query, session=resumed_session)}\n") + + +async def suspend_resume_in_memory_session() -> None: + """Demonstrates how to suspend and resume an in-memory session.""" + print("=== Suspend-Resume In-Memory Session ===") + + # OpenAI Chat Client is used as an example here, + # other chat clients can be used as well. + agent = OpenAIChatClient().as_agent( + name="MemoryBot", instructions="You are a helpful assistant that remembers our conversation." + ) + + # Start a new session for the agent conversation. + session = agent.create_session() + + # Respond to user input. + query = "Hello! My name is Alice and I love pizza." + print(f"User: {query}") + print(f"Agent: {await agent.run(query, session=session)}\n") + + # Serialize the session state, so it can be stored for later use. + serialized_session = session.to_dict() + + # The session can now be saved to a database, file, or any other storage mechanism and loaded again later. + print(f"Serialized session: {serialized_session}\n") + + # Deserialize the session state after loading from storage. + resumed_session = AgentSession.from_dict(serialized_session) + + # Respond to user input. 
+ query = "What do you remember about me?" + print(f"User: {query}") + print(f"Agent: {await agent.run(query, session=resumed_session)}\n") + + +async def main() -> None: + print("=== Suspend-Resume Session Examples ===") + await suspend_resume_service_managed_session() + await suspend_resume_in_memory_session() + + +if __name__ == "__main__": + asyncio.run(main()) From b42159d9a28dbee00cb00d145787c204ad6fb3d7 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Feb 2026 15:33:05 +0100 Subject: [PATCH 14/28] refactor: UserInfoMemory stores state in session.state instead of instance attributes --- .../sessions/simple_context_provider.py | 98 +++++++++---------- 1 file changed, 47 insertions(+), 51 deletions(-) diff --git a/python/samples/getting_started/sessions/simple_context_provider.py b/python/samples/getting_started/sessions/simple_context_provider.py index fd2a7ce747..0bf52197a5 100644 --- a/python/samples/getting_started/sessions/simple_context_provider.py +++ b/python/samples/getting_started/sessions/simple_context_provider.py @@ -15,19 +15,24 @@ class UserInfo(BaseModel): class UserInfoMemory(BaseContextProvider): - def __init__(self, client: SupportsChatGetResponse, user_info: UserInfo | None = None, **kwargs: Any): - """Create the memory. + """Context provider that extracts and remembers user info (name, age). - If you pass in kwargs, they will be attempted to be used to create a UserInfo object. - """ + State is stored in ``session.state["user-info-memory"]`` so it survives + serialization via ``session.to_dict()`` / ``AgentSession.from_dict()``. 
+ """ + + def __init__(self, client: SupportsChatGetResponse): super().__init__("user-info-memory") self._chat_client = client - if user_info: - self.user_info = user_info - elif kwargs: - self.user_info = UserInfo.model_validate(kwargs) - else: - self.user_info = UserInfo() + + def _get_user_info(self, state: dict[str, Any]) -> UserInfo: + """Load UserInfo from session state, creating it if absent.""" + my_state = state.setdefault(self.source_id, {}) + return UserInfo.model_validate(my_state.get("user_info", {})) + + def _save_user_info(self, state: dict[str, Any], user_info: UserInfo) -> None: + """Persist UserInfo back to session state.""" + state.setdefault(self.source_id, {})["user_info"] = user_info.model_dump() async def after_run( self, @@ -38,32 +43,30 @@ async def after_run( state: dict[str, Any], ) -> None: """Extract user information from messages after each agent call.""" + user_info = self._get_user_info(state) + if user_info.name is not None and user_info.age is not None: + return # Already have everything + request_messages = context.get_messages() - # Check if we need to extract user info from user messages user_messages = [msg for msg in request_messages if hasattr(msg, "role") and msg.role == "user"] # type: ignore - - if (self.user_info.name is None or self.user_info.age is None) and user_messages: - try: - # Use the chat client to extract structured information - result = await self._chat_client.get_response( - messages=request_messages, # type: ignore - instructions="Extract the user's name and age from the message if present. 
" - "If not present return nulls.", - options={"response_format": UserInfo}, - ) - - # Update user info with extracted data - try: - extracted = result.value - if self.user_info.name is None and extracted.name: - self.user_info.name = extracted.name - if self.user_info.age is None and extracted.age: - self.user_info.age = extracted.age - except Exception: - pass # Failed to extract, continue without updating - - except Exception: - pass # Failed to extract, continue without updating + if not user_messages: + return + + try: + result = await self._chat_client.get_response( + messages=request_messages, # type: ignore + instructions="Extract the user's name and age from the message if present. " + "If not present return nulls.", + options={"response_format": UserInfo}, + ) + extracted = result.value + if user_info.name is None and extracted.name: + user_info.name = extracted.name + if user_info.age is None and extracted.age: + user_info.age = extracted.age + self._save_user_info(state, user_info) + except Exception: + pass # Failed to extract, continue without updating async def before_run( self, @@ -74,55 +77,48 @@ async def before_run( state: dict[str, Any], ) -> None: """Provide user information context before each agent call.""" + user_info = self._get_user_info(state) instructions: list[str] = [] - if self.user_info.name is None: + if user_info.name is None: instructions.append( "Ask the user for their name and politely decline to answer any questions until they provide it." ) else: - instructions.append(f"The user's name is {self.user_info.name}.") + instructions.append(f"The user's name is {user_info.name}.") - if self.user_info.age is None: + if user_info.age is None: instructions.append( "Ask the user for their age and politely decline to answer any questions until they provide it." 
) else: - instructions.append(f"The user's age is {self.user_info.age}.") + instructions.append(f"The user's age is {user_info.age}.") - # Add context with additional instructions context.extend_instructions(self.source_id, " ".join(instructions)) - def serialize(self) -> str: - """Serialize the user info for session persistence.""" - return self.user_info.model_dump_json() - async def main(): async with AzureCliCredential() as credential: client = AzureAIClient(credential=credential) - # Create the memory provider memory_provider = UserInfoMemory(client) - # Create the agent with memory async with Agent( client=client, instructions="You are a friendly assistant. Always address the user by their name.", context_providers=[memory_provider], ) as agent: - # Create a new session for the conversation session = agent.create_session() print(await agent.run("Hello, what is the square root of 9?", session=session)) print(await agent.run("My name is Ruaidhrí", session=session)) print(await agent.run("I am 20 years old", session=session)) - # Access the memory component and inspect the memories - if memory_provider: - print() - print(f"MEMORY - User Name: {memory_provider.user_info.name}") - print(f"MEMORY - User Age: {memory_provider.user_info.age}") + # Inspect extracted user info from session state + user_info = memory_provider._get_user_info(session.state) + print() + print(f"MEMORY - User Name: {user_info.name}") + print(f"MEMORY - User Age: {user_info.age}") if __name__ == "__main__": From d95ed0639f6781b892dfa9d58536495935267ec3 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Feb 2026 15:38:47 +0100 Subject: [PATCH 15/28] feat: add Pydantic BaseModel support to session state serialization Pydantic models stored in session.state are now automatically serialized via model_dump() and restored via model_validate() during to_dict()/from_dict() round-trips. 
Models are auto-registered on first serialization; use register_state_type() for cold-start deserialization. Also export register_state_type as a public API. --- .../core/agent_framework/_sessions.py | 55 +++++++++++++++++-- .../sessions/simple_context_provider.py | 11 +++- 2 files changed, 58 insertions(+), 8 deletions(-) diff --git a/python/packages/core/agent_framework/_sessions.py b/python/packages/core/agent_framework/_sessions.py index 015248d844..6240c632c2 100644 --- a/python/packages/core/agent_framework/_sessions.py +++ b/python/packages/core/agent_framework/_sessions.py @@ -30,6 +30,7 @@ "BaseHistoryProvider", "InMemoryHistoryProvider", "SessionContext", + "register_state_type", ] @@ -37,16 +38,50 @@ _STATE_TYPE_REGISTRY: dict[str, type] = {} -def _register_state_type(cls: type) -> None: - """Register a type for automatic deserialization in session state.""" +def register_state_type(cls: type) -> None: + """Register a type for automatic deserialization in session state. + + Call this for any custom type (including Pydantic models) that you store + in ``session.state`` and want to survive ``to_dict()`` / ``from_dict()`` + round-trips. Types with ``to_dict``/``from_dict`` methods or Pydantic + ``BaseModel`` subclasses are handled automatically. + + The type identifier defaults to ``cls.__name__.lower()`` but can be + overridden by defining a ``_get_type_identifier`` classmethod. + + Note: + Pydantic models are auto-registered on first serialization, but + pre-registering ensures deserialization works even if the model + hasn't been serialized in this process yet (e.g. cold-start restore). + + Args: + cls: The type to register. 
+ """ type_id: str = getattr(cls, "_get_type_identifier", lambda: cls.__name__.lower())() _STATE_TYPE_REGISTRY[type_id] = cls +# Keep internal alias for framework use +_register_state_type = register_state_type + + def _serialize_value(value: Any) -> Any: - """Serialize a single value, handling objects with to_dict().""" + """Serialize a single value, handling objects with to_dict() and Pydantic models.""" if hasattr(value, "to_dict") and callable(value.to_dict): return value.to_dict() # pyright: ignore[reportUnknownMemberType] + # Pydantic BaseModel support — import lazily to avoid hard dep at module level + try: + from pydantic import BaseModel + + if isinstance(value, BaseModel): + data = value.model_dump() + type_id: str = getattr(value.__class__, "_get_type_identifier", lambda: value.__class__.__name__.lower())() + data["type"] = type_id + # Auto-register for round-trip deserialization + _STATE_TYPE_REGISTRY.setdefault(type_id, value.__class__) + return data + except ImportError: + pass if isinstance(value, list): return [_serialize_value(item) for item in value] # pyright: ignore[reportUnknownVariableType] if isinstance(value, dict): @@ -59,8 +94,18 @@ def _deserialize_value(value: Any) -> Any: if isinstance(value, dict) and "type" in value: type_id = str(value["type"]) # pyright: ignore[reportUnknownArgumentType] cls = _STATE_TYPE_REGISTRY.get(type_id) - if cls is not None and hasattr(cls, "from_dict"): - return cls.from_dict(value) # type: ignore[union-attr] + if cls is not None: + if hasattr(cls, "from_dict"): + return cls.from_dict(value) # type: ignore[union-attr] + # Pydantic BaseModel support + try: + from pydantic import BaseModel + + if issubclass(cls, BaseModel): + data = {k: v for k, v in value.items() if k != "type"} + return cls.model_validate(data) + except ImportError: + pass if isinstance(value, list): return [_deserialize_value(item) for item in value] # pyright: ignore[reportUnknownVariableType] if isinstance(value, dict): diff --git 
a/python/samples/getting_started/sessions/simple_context_provider.py b/python/samples/getting_started/sessions/simple_context_provider.py index 0bf52197a5..fb0c62e2b8 100644 --- a/python/samples/getting_started/sessions/simple_context_provider.py +++ b/python/samples/getting_started/sessions/simple_context_provider.py @@ -28,11 +28,16 @@ def __init__(self, client: SupportsChatGetResponse): def _get_user_info(self, state: dict[str, Any]) -> UserInfo: """Load UserInfo from session state, creating it if absent.""" my_state = state.setdefault(self.source_id, {}) - return UserInfo.model_validate(my_state.get("user_info", {})) + info = my_state.get("user_info") + if isinstance(info, UserInfo): + return info + user_info = UserInfo() + my_state["user_info"] = user_info + return user_info def _save_user_info(self, state: dict[str, Any], user_info: UserInfo) -> None: - """Persist UserInfo back to session state.""" - state.setdefault(self.source_id, {})["user_info"] = user_info.model_dump() + """Persist UserInfo back to session state (stored as-is; serialized automatically).""" + state.setdefault(self.source_id, {})["user_info"] = user_info async def after_run( self, From ccade02ee381a5fcdcf53f861c26c58c4f671204 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Feb 2026 16:09:09 +0100 Subject: [PATCH 16/28] fix mem0 --- python/packages/mem0/README.md | 2 +- .../sessions/simple_context_provider.py | 129 +++++++++--------- 2 files changed, 62 insertions(+), 69 deletions(-) diff --git a/python/packages/mem0/README.md b/python/packages/mem0/README.md index 6ca522f4c8..4d9cb64530 100644 --- a/python/packages/mem0/README.md +++ b/python/packages/mem0/README.md @@ -27,5 +27,5 @@ Mem0's telemetry is **disabled by default** when using this package. 
If you want import os os.environ["MEM0_TELEMETRY"] = "true" -from agent_framework.mem0 import Mem0Provider +from agent_framework.mem0 import Mem0ContextProvider ``` diff --git a/python/samples/getting_started/sessions/simple_context_provider.py b/python/samples/getting_started/sessions/simple_context_provider.py index fb0c62e2b8..7ef1ba6ea4 100644 --- a/python/samples/getting_started/sessions/simple_context_provider.py +++ b/python/samples/getting_started/sessions/simple_context_provider.py @@ -3,9 +3,15 @@ import asyncio from typing import Any -from agent_framework import Agent, AgentSession, BaseContextProvider, SessionContext, SupportsChatGetResponse -from agent_framework.azure import AzureAIClient -from azure.identity.aio import AzureCliCredential +from agent_framework import ( + Agent, + AgentSession, + BaseContextProvider, + SessionContext, + SupportsChatGetResponse, +) +from agent_framework.azure import AzureOpenAIResponsesClient +from azure.identity import AzureCliCredential from pydantic import BaseModel @@ -25,19 +31,35 @@ def __init__(self, client: SupportsChatGetResponse): super().__init__("user-info-memory") self._chat_client = client - def _get_user_info(self, state: dict[str, Any]) -> UserInfo: - """Load UserInfo from session state, creating it if absent.""" + async def before_run( + self, + *, + agent: Any, + session: AgentSession | None, + context: SessionContext, + state: dict[str, Any], + ) -> None: + """Provide user information context before each agent call.""" my_state = state.setdefault(self.source_id, {}) - info = my_state.get("user_info") - if isinstance(info, UserInfo): - return info - user_info = UserInfo() - my_state["user_info"] = user_info - return user_info + user_info = my_state.setdefault("user_info", UserInfo()) - def _save_user_info(self, state: dict[str, Any], user_info: UserInfo) -> None: - """Persist UserInfo back to session state (stored as-is; serialized automatically).""" - state.setdefault(self.source_id, {})["user_info"] = 
user_info + instructions: list[str] = [] + + if user_info.name is None: + instructions.append( + "Ask the user for their name and politely decline to answer any questions until they provide it." + ) + else: + instructions.append(f"The user's name is {user_info.name}.") + + if user_info.age is None: + instructions.append( + "Ask the user for their age and politely decline to answer any questions until they provide it." + ) + else: + instructions.append(f"The user's age is {user_info.age}.") + + context.extend_instructions(self.source_id, " ".join(instructions)) async def after_run( self, @@ -48,11 +70,12 @@ async def after_run( state: dict[str, Any], ) -> None: """Extract user information from messages after each agent call.""" - user_info = self._get_user_info(state) + my_state = state.setdefault(self.source_id, {}) + user_info = my_state.setdefault("user_info", UserInfo()) if user_info.name is not None and user_info.age is not None: return # Already have everything - request_messages = context.get_messages() + request_messages = context.get_messages(include_input=True, include_response=True) user_messages = [msg for msg in request_messages if hasattr(msg, "role") and msg.role == "user"] # type: ignore if not user_messages: return @@ -65,65 +88,35 @@ async def after_run( options={"response_format": UserInfo}, ) extracted = result.value - if user_info.name is None and extracted.name: + if extracted and user_info.name is None and extracted.name: user_info.name = extracted.name - if user_info.age is None and extracted.age: + if extracted and user_info.age is None and extracted.age: user_info.age = extracted.age - self._save_user_info(state, user_info) + state.setdefault(self.source_id, {})["user_info"] = user_info except Exception: pass # Failed to extract, continue without updating - async def before_run( - self, - *, - agent: Any, - session: AgentSession | None, - context: SessionContext, - state: dict[str, Any], - ) -> None: - """Provide user information context 
before each agent call.""" - user_info = self._get_user_info(state) - instructions: list[str] = [] - - if user_info.name is None: - instructions.append( - "Ask the user for their name and politely decline to answer any questions until they provide it." - ) - else: - instructions.append(f"The user's name is {user_info.name}.") - - if user_info.age is None: - instructions.append( - "Ask the user for their age and politely decline to answer any questions until they provide it." - ) - else: - instructions.append(f"The user's age is {user_info.age}.") - - context.extend_instructions(self.source_id, " ".join(instructions)) - async def main(): - async with AzureCliCredential() as credential: - client = AzureAIClient(credential=credential) - - memory_provider = UserInfoMemory(client) - - async with Agent( - client=client, - instructions="You are a friendly assistant. Always address the user by their name.", - context_providers=[memory_provider], - ) as agent: - session = agent.create_session() - - print(await agent.run("Hello, what is the square root of 9?", session=session)) - print(await agent.run("My name is Ruaidhrí", session=session)) - print(await agent.run("I am 20 years old", session=session)) - - # Inspect extracted user info from session state - user_info = memory_provider._get_user_info(session.state) - print() - print(f"MEMORY - User Name: {user_info.name}") - print(f"MEMORY - User Age: {user_info.age}") + client = AzureOpenAIResponsesClient(credential=AzureCliCredential()) + + async with Agent( + client=client, + instructions="You are a friendly assistant. 
Always address the user by their name.", + default_options={"store": True}, + context_providers=[UserInfoMemory(client)], + ) as agent: + session = agent.create_session() + + print(await agent.run("Hello, what is the square root of 9?", session=session)) + print(await agent.run("My name is Ruaidhrí", session=session)) + print(await agent.run("I am 20 years old", session=session)) + + # Inspect extracted user info from session state + user_info = session.state.get("user-info-memory", {}).get("user_info", UserInfo()) + print() + print(f"MEMORY - User Name: {user_info.name}") + print(f"MEMORY - User Age: {user_info.age}") if __name__ == "__main__": From 16ef1a4f9332fe5f04219854f224d96753b6cda1 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Feb 2026 16:23:24 +0100 Subject: [PATCH 17/28] Update sample README links and descriptions for session terminology MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace 'thread' with 'session' in sample descriptions across all READMEs - Update file links for renamed samples (mem0_sessions, redis_sessions, etc.) 
- Fix Threads section → Sessions section in main samples/README.md - Update tools, middleware, workflows, durabletask, azure_functions READMEs - Update architecture diagrams in concepts/tools/README.md - Update migration guides (autogen, semantic-kernel) --- .../samples/02-agents/providers/azure_ai/README.md | 4 ++-- .../02-agents/providers/azure_ai_agent/README.md | 6 +++--- .../02-agents/providers/azure_openai/README.md | 6 +++--- python/samples/02-agents/providers/custom/README.md | 4 ++-- .../02-agents/providers/github_copilot/README.md | 2 +- python/samples/02-agents/providers/openai/README.md | 6 +++--- python/samples/03-workflows/README.md | 2 +- .../azure_functions/01_single_agent/README.md | 2 +- .../azure_functions/02_multi_agent/README.md | 2 +- .../04_single_agent_orchestration_chaining/README.md | 4 ++-- .../04-hosting/durabletask/01_single_agent/README.md | 4 ++-- .../04-hosting/durabletask/02_multi_agent/README.md | 6 +++--- .../04_single_agent_orchestration_chaining/README.md | 6 +++--- .../README.md | 4 ++-- .../07_single_agent_orchestration_hitl/README.md | 2 +- python/samples/autogen-migration/README.md | 2 +- python/samples/getting_started/sessions/README.md | 12 ++++++------ ...ge_store_thread.py => custom_history_provider.py} | 0 .../samples/getting_started/sessions/mem0/README.md | 10 +++++----- .../mem0/{mem0_threads.py => mem0_sessions.py} | 0 .../samples/getting_started/sessions/redis/README.md | 10 +++++----- .../redis/{redis_threads.py => redis_sessions.py} | 0 ...age_store_thread.py => redis_history_provider.py} | 0 ...nd_resume_thread.py => suspend_resume_session.py} | 0 python/samples/semantic-kernel-migration/README.md | 6 +++--- 25 files changed, 50 insertions(+), 50 deletions(-) rename python/samples/getting_started/sessions/{custom_chat_message_store_thread.py => custom_history_provider.py} (100%) rename python/samples/getting_started/sessions/mem0/{mem0_threads.py => mem0_sessions.py} (100%) rename 
python/samples/getting_started/sessions/redis/{redis_threads.py => redis_sessions.py} (100%) rename python/samples/getting_started/sessions/{redis_chat_message_store_thread.py => redis_history_provider.py} (100%) rename python/samples/getting_started/sessions/{suspend_resume_thread.py => suspend_resume_session.py} (100%) diff --git a/python/samples/02-agents/providers/azure_ai/README.md b/python/samples/02-agents/providers/azure_ai/README.md index f7dfb0f8ce..075219462c 100644 --- a/python/samples/02-agents/providers/azure_ai/README.md +++ b/python/samples/02-agents/providers/azure_ai/README.md @@ -31,7 +31,7 @@ This folder contains examples demonstrating different ways to create and use age | [`azure_ai_with_search_context_agentic.py`](../../sessions/azure_ai_search/azure_ai_with_search_context_agentic.py) | Shows how to use AzureAISearchContextProvider with agentic mode. Uses Knowledge Bases for multi-hop reasoning across documents with query planning. Recommended for most scenarios - slightly slower with more token consumption for query planning, but more accurate results. | | [`azure_ai_with_search_context_semantic.py`](../../sessions/azure_ai_search/azure_ai_with_search_context_semantic.py) | Shows how to use AzureAISearchContextProvider with semantic mode. Fast hybrid search with vector + keyword search and semantic ranking for RAG. Best for simple queries where speed is critical. | | [`azure_ai_with_sharepoint.py`](azure_ai_with_sharepoint.py) | Shows how to use SharePoint grounding with Azure AI agents to search through SharePoint content and answer user questions with proper citations. Requires a SharePoint connection configured in your Azure AI project. | -| [`azure_ai_with_thread.py`](azure_ai_with_thread.py) | Demonstrates thread management with Azure AI agents, including automatic thread creation for stateless conversations and explicit thread management for maintaining conversation context across multiple interactions. 
| +| [`azure_ai_with_thread.py`](azure_ai_with_thread.py) | Demonstrates session management with Azure AI agents, including automatic session creation for stateless conversations and explicit session management for maintaining conversation context across multiple interactions. | | [`azure_ai_with_image_generation.py`](azure_ai_with_image_generation.py) | Shows how to use `AzureAIClient.get_image_generation_tool()` with Azure AI agents to generate images based on text prompts. | | [`azure_ai_with_memory_search.py`](azure_ai_with_memory_search.py) | Shows how to use memory search functionality with Azure AI agents for conversation persistence. Demonstrates creating memory stores and enabling agents to search through conversation history. | | [`azure_ai_with_microsoft_fabric.py`](azure_ai_with_microsoft_fabric.py) | Shows how to use Microsoft Fabric with Azure AI agents to query Fabric data sources and provide responses based on data analysis. Requires a Microsoft Fabric connection configured in your Azure AI project. | @@ -92,4 +92,4 @@ python azure_ai_with_code_interpreter.py # ... etc ``` -The examples demonstrate various patterns for working with Azure AI agents, from basic usage to advanced scenarios like thread management and structured outputs. +The examples demonstrate various patterns for working with Azure AI agents, from basic usage to advanced scenarios like session management and structured outputs. diff --git a/python/samples/02-agents/providers/azure_ai_agent/README.md b/python/samples/02-agents/providers/azure_ai_agent/README.md index c91a66d558..a69572ac9d 100644 --- a/python/samples/02-agents/providers/azure_ai_agent/README.md +++ b/python/samples/02-agents/providers/azure_ai_agent/README.md @@ -38,7 +38,7 @@ async with ( | [`azure_ai_with_code_interpreter_file_generation.py`](azure_ai_with_code_interpreter_file_generation.py) | Shows how to retrieve file IDs from code interpreter generated files using both streaming and non-streaming approaches. 
| | [`azure_ai_with_code_interpreter.py`](azure_ai_with_code_interpreter.py) | Shows how to use `AzureAIAgentClient.get_code_interpreter_tool()` with Azure AI agents to write and execute Python code. Includes helper methods for accessing code interpreter data from response chunks. | | [`azure_ai_with_existing_agent.py`](azure_ai_with_existing_agent.py) | Shows how to work with an existing SDK Agent object using `provider.as_agent()`. This wraps the agent without making HTTP calls. | -| [`azure_ai_with_existing_thread.py`](azure_ai_with_existing_thread.py) | Shows how to work with a pre-existing thread by providing the thread ID. Demonstrates proper cleanup of manually created threads. | +| [`azure_ai_with_existing_thread.py`](azure_ai_with_existing_thread.py) | Shows how to work with a pre-existing session by providing the session ID. Demonstrates proper cleanup of manually created sessions. | | [`azure_ai_with_explicit_settings.py`](azure_ai_with_explicit_settings.py) | Shows how to create an agent with explicitly configured provider settings, including project endpoint and model deployment name. | | [`azure_ai_with_azure_ai_search.py`](azure_ai_with_azure_ai_search.py) | Demonstrates how to use Azure AI Search with Azure AI agents. Shows how to create an agent with search tools using the SDK directly and wrap it with `provider.get_agent()`. | | [`azure_ai_with_file_search.py`](azure_ai_with_file_search.py) | Demonstrates how to use `AzureAIAgentClient.get_file_search_tool()` with Azure AI agents to search through uploaded documents. Shows file upload, vector store creation, and querying document content. | @@ -46,9 +46,9 @@ async with ( | [`azure_ai_with_hosted_mcp.py`](azure_ai_with_hosted_mcp.py) | Shows how to use `AzureAIAgentClient.get_mcp_tool()` with hosted Model Context Protocol (MCP) servers for enhanced functionality and tool integration. Demonstrates remote MCP server connections and tool discovery. 
| | [`azure_ai_with_local_mcp.py`](azure_ai_with_local_mcp.py) | Shows how to integrate Azure AI agents with local Model Context Protocol (MCP) servers for enhanced functionality and tool integration. Demonstrates both agent-level and run-level tool configuration. | | [`azure_ai_with_multiple_tools.py`](azure_ai_with_multiple_tools.py) | Demonstrates how to use multiple tools together with Azure AI agents, including web search, MCP servers, and function tools using client static methods. Shows coordinated multi-tool interactions and approval workflows. | -| [`azure_ai_with_openapi_tools.py`](azure_ai_with_openapi_tools.py) | Demonstrates how to use OpenAPI tools with Azure AI agents to integrate external REST APIs. Shows OpenAPI specification loading, anonymous authentication, thread context management, and coordinated multi-API conversations. | +| [`azure_ai_with_openapi_tools.py`](azure_ai_with_openapi_tools.py) | Demonstrates how to use OpenAPI tools with Azure AI agents to integrate external REST APIs. Shows OpenAPI specification loading, anonymous authentication, session context management, and coordinated multi-API conversations. | | [`azure_ai_with_response_format.py`](azure_ai_with_response_format.py) | Demonstrates how to use structured outputs with Azure AI agents using Pydantic models. | -| [`azure_ai_with_thread.py`](azure_ai_with_thread.py) | Demonstrates thread management with Azure AI agents, including automatic thread creation for stateless conversations and explicit thread management for maintaining conversation context across multiple interactions. | +| [`azure_ai_with_thread.py`](azure_ai_with_thread.py) | Demonstrates session management with Azure AI agents, including automatic session creation for stateless conversations and explicit session management for maintaining conversation context across multiple interactions. 
| ## Environment Variables diff --git a/python/samples/02-agents/providers/azure_openai/README.md b/python/samples/02-agents/providers/azure_openai/README.md index 614e60b14d..460c2861fe 100644 --- a/python/samples/02-agents/providers/azure_openai/README.md +++ b/python/samples/02-agents/providers/azure_openai/README.md @@ -11,11 +11,11 @@ This folder contains examples demonstrating different ways to create and use age | [`azure_assistants_with_existing_assistant.py`](azure_assistants_with_existing_assistant.py) | Shows how to work with a pre-existing assistant by providing the assistant ID to the Azure Assistants client. Demonstrates proper cleanup of manually created assistants. | | [`azure_assistants_with_explicit_settings.py`](azure_assistants_with_explicit_settings.py) | Shows how to initialize an agent with a specific assistants client, configuring settings explicitly including endpoint and deployment name. | | [`azure_assistants_with_function_tools.py`](azure_assistants_with_function_tools.py) | Demonstrates how to use function tools with agents. Shows both agent-level tools (defined when creating the agent) and query-level tools (provided with specific queries). | -| [`azure_assistants_with_thread.py`](azure_assistants_with_thread.py) | Demonstrates thread management with Azure agents, including automatic thread creation for stateless conversations and explicit thread management for maintaining conversation context across multiple interactions. | +| [`azure_assistants_with_thread.py`](azure_assistants_with_thread.py) | Demonstrates session management with Azure agents, including automatic session creation for stateless conversations and explicit session management for maintaining conversation context across multiple interactions. | | [`azure_chat_client_basic.py`](azure_chat_client_basic.py) | The simplest way to create an agent using `Agent` with `AzureOpenAIChatClient`. 
Shows both streaming and non-streaming responses for chat-based interactions with Azure OpenAI models. | | [`azure_chat_client_with_explicit_settings.py`](azure_chat_client_with_explicit_settings.py) | Shows how to initialize an agent with a specific chat client, configuring settings explicitly including endpoint and deployment name. | | [`azure_chat_client_with_function_tools.py`](azure_chat_client_with_function_tools.py) | Demonstrates how to use function tools with agents. Shows both agent-level tools (defined when creating the agent) and query-level tools (provided with specific queries). | -| [`azure_chat_client_with_thread.py`](azure_chat_client_with_thread.py) | Demonstrates thread management with Azure agents, including automatic thread creation for stateless conversations and explicit thread management for maintaining conversation context across multiple interactions. | +| [`azure_chat_client_with_thread.py`](azure_chat_client_with_thread.py) | Demonstrates session management with Azure agents, including automatic session creation for stateless conversations and explicit session management for maintaining conversation context across multiple interactions. | | [`azure_responses_client_basic.py`](azure_responses_client_basic.py) | The simplest way to create an agent using `Agent` with `AzureOpenAIResponsesClient`. Shows both streaming and non-streaming responses for structured response generation with Azure OpenAI models. | | [`azure_responses_client_code_interpreter_files.py`](azure_responses_client_code_interpreter_files.py) | Demonstrates using `AzureOpenAIResponsesClient.get_code_interpreter_tool()` with file uploads for data analysis. Shows how to create, upload, and analyze CSV files using Python code execution with Azure OpenAI Responses. | | [`azure_responses_client_image_analysis.py`](azure_responses_client_image_analysis.py) | Shows how to use Azure OpenAI Responses for image analysis and vision tasks. 
Demonstrates multi-modal messages combining text and image content using remote URLs. | @@ -26,7 +26,7 @@ This folder contains examples demonstrating different ways to create and use age | [`azure_responses_client_with_function_tools.py`](azure_responses_client_with_function_tools.py) | Demonstrates how to use function tools with agents. Shows both agent-level tools (defined when creating the agent) and query-level tools (provided with specific queries). | | [`azure_responses_client_with_hosted_mcp.py`](azure_responses_client_with_hosted_mcp.py) | Shows how to integrate Azure OpenAI Responses Client with hosted Model Context Protocol (MCP) servers using `AzureOpenAIResponsesClient.get_mcp_tool()` for extended functionality. | | [`azure_responses_client_with_local_mcp.py`](azure_responses_client_with_local_mcp.py) | Shows how to integrate Azure OpenAI Responses Client with local Model Context Protocol (MCP) servers using MCPStreamableHTTPTool for extended functionality. | -| [`azure_responses_client_with_thread.py`](azure_responses_client_with_thread.py) | Demonstrates thread management with Azure agents, including automatic thread creation for stateless conversations and explicit thread management for maintaining conversation context across multiple interactions. | +| [`azure_responses_client_with_thread.py`](azure_responses_client_with_thread.py) | Demonstrates session management with Azure agents, including automatic session creation for stateless conversations and explicit session management for maintaining conversation context across multiple interactions. 
| ## Environment Variables diff --git a/python/samples/02-agents/providers/custom/README.md b/python/samples/02-agents/providers/custom/README.md index f8921b1f24..f2d67e0315 100644 --- a/python/samples/02-agents/providers/custom/README.md +++ b/python/samples/02-agents/providers/custom/README.md @@ -6,7 +6,7 @@ This folder contains examples demonstrating how to implement custom agents and c | File | Description | |------|-------------| -| [`custom_agent.py`](custom_agent.py) | Shows how to create custom agents by extending the `BaseAgent` class. Demonstrates the `EchoAgent` implementation with both streaming and non-streaming responses, proper thread management, and message history handling. | +| [`custom_agent.py`](custom_agent.py) | Shows how to create custom agents by extending the `BaseAgent` class. Demonstrates the `EchoAgent` implementation with both streaming and non-streaming responses, proper session management, and message history handling. | | [`custom_chat_client.py`](../../chat_client/custom_chat_client.py) | Demonstrates how to create custom chat clients by extending the `BaseChatClient` class. Shows a `EchoingChatClient` implementation and how to integrate it with `Agent` using the `as_agent()` method. 
| ## Key Takeaways @@ -15,7 +15,7 @@ This folder contains examples demonstrating how to implement custom agents and c - Custom agents give you complete control over the agent's behavior - You must implement both `run()` for both the `stream=True` and `stream=False` cases - Use `self._normalize_messages()` to handle different input message formats -- Use `self._notify_thread_of_new_messages()` to properly manage conversation history +- Store messages in `session.state` to properly manage conversation history ### Custom Chat Clients - Custom chat clients allow you to integrate any backend service or create new LLM providers diff --git a/python/samples/02-agents/providers/github_copilot/README.md b/python/samples/02-agents/providers/github_copilot/README.md index c69ffe37eb..572ec9c444 100644 --- a/python/samples/02-agents/providers/github_copilot/README.md +++ b/python/samples/02-agents/providers/github_copilot/README.md @@ -29,7 +29,7 @@ The following environment variables can be configured: | File | Description | |------|-------------| | [`github_copilot_basic.py`](github_copilot_basic.py) | The simplest way to create an agent using `GitHubCopilotAgent`. Demonstrates both streaming and non-streaming responses with function tools. | -| [`github_copilot_with_session.py`](github_copilot_with_session.py) | Shows session management with automatic creation, persistence via thread objects, and resuming sessions by ID. | +| [`github_copilot_with_session.py`](github_copilot_with_session.py) | Shows session management with automatic creation, persistence via session objects, and resuming sessions by ID. | | [`github_copilot_with_shell.py`](github_copilot_with_shell.py) | Shows how to enable shell command execution permissions. Demonstrates running system commands like listing files and getting system information. | | [`github_copilot_with_file_operations.py`](github_copilot_with_file_operations.py) | Shows how to enable file read and write permissions. 
Demonstrates reading file contents and creating new files. | | [`github_copilot_with_url.py`](github_copilot_with_url.py) | Shows how to enable URL fetching permissions. Demonstrates fetching and processing web content. | diff --git a/python/samples/02-agents/providers/openai/README.md b/python/samples/02-agents/providers/openai/README.md index 579bfec187..bfcdc94d90 100644 --- a/python/samples/02-agents/providers/openai/README.md +++ b/python/samples/02-agents/providers/openai/README.md @@ -14,12 +14,12 @@ This folder contains examples demonstrating different ways to create and use age | [`openai_assistants_with_file_search.py`](openai_assistants_with_file_search.py) | Using `OpenAIAssistantsClient.get_file_search_tool()` with `OpenAIAssistantProvider` for file search capabilities. | | [`openai_assistants_with_function_tools.py`](openai_assistants_with_function_tools.py) | Function tools with `OpenAIAssistantProvider` at both agent-level and query-level. | | [`openai_assistants_with_response_format.py`](openai_assistants_with_response_format.py) | Structured outputs with `OpenAIAssistantProvider` using Pydantic models. | -| [`openai_assistants_with_thread.py`](openai_assistants_with_thread.py) | Thread management with `OpenAIAssistantProvider` for conversation context persistence. | +| [`openai_assistants_with_thread.py`](openai_assistants_with_thread.py) | Session management with `OpenAIAssistantProvider` for conversation context persistence. | | [`openai_chat_client_basic.py`](openai_chat_client_basic.py) | The simplest way to create an agent using `Agent` with `OpenAIChatClient`. Shows both streaming and non-streaming responses for chat-based interactions with OpenAI models. | | [`openai_chat_client_with_explicit_settings.py`](openai_chat_client_with_explicit_settings.py) | Shows how to initialize an agent with a specific chat client, configuring settings explicitly including API key and model ID. 
| | [`openai_chat_client_with_function_tools.py`](openai_chat_client_with_function_tools.py) | Demonstrates how to use function tools with agents. Shows both agent-level tools (defined when creating the agent) and query-level tools (provided with specific queries). | | [`openai_chat_client_with_local_mcp.py`](openai_chat_client_with_local_mcp.py) | Shows how to integrate OpenAI agents with local Model Context Protocol (MCP) servers for enhanced functionality and tool integration. | -| [`openai_chat_client_with_thread.py`](openai_chat_client_with_thread.py) | Demonstrates thread management with OpenAI agents, including automatic thread creation for stateless conversations and explicit thread management for maintaining conversation context across multiple interactions. | +| [`openai_chat_client_with_thread.py`](openai_chat_client_with_thread.py) | Demonstrates session management with OpenAI agents, including automatic session creation for stateless conversations and explicit session management for maintaining conversation context across multiple interactions. | | [`openai_chat_client_with_web_search.py`](openai_chat_client_with_web_search.py) | Shows how to use `OpenAIChatClient.get_web_search_tool()` for web search capabilities with OpenAI agents. | | [`openai_chat_client_with_runtime_json_schema.py`](openai_chat_client_with_runtime_json_schema.py) | Shows how to supply a runtime JSON Schema via `additional_chat_options` for structured output without defining a Pydantic model. | | [`openai_responses_client_basic.py`](openai_responses_client_basic.py) | The simplest way to create an agent using `Agent` with `OpenAIResponsesClient`. Shows both streaming and non-streaming responses for structured response generation with OpenAI models. 
| @@ -37,7 +37,7 @@ This folder contains examples demonstrating different ways to create and use age | [`openai_responses_client_with_local_mcp.py`](openai_responses_client_with_local_mcp.py) | Shows how to integrate OpenAI agents with local Model Context Protocol (MCP) servers for enhanced functionality and tool integration. | | [`openai_responses_client_with_runtime_json_schema.py`](openai_responses_client_with_runtime_json_schema.py) | Shows how to supply a runtime JSON Schema via `additional_chat_options` for structured output without defining a Pydantic model. | | [`openai_responses_client_with_structured_output.py`](openai_responses_client_with_structured_output.py) | Demonstrates how to use structured outputs with OpenAI agents to get structured data responses in predefined formats. | -| [`openai_responses_client_with_thread.py`](openai_responses_client_with_thread.py) | Demonstrates thread management with OpenAI agents, including automatic thread creation for stateless conversations and explicit thread management for maintaining conversation context across multiple interactions. | +| [`openai_responses_client_with_thread.py`](openai_responses_client_with_thread.py) | Demonstrates session management with OpenAI agents, including automatic session creation for stateless conversations and explicit session management for maintaining conversation context across multiple interactions. | | [`openai_responses_client_with_web_search.py`](openai_responses_client_with_web_search.py) | Shows how to use `OpenAIResponsesClient.get_web_search_tool()` for web search capabilities. | ## Environment Variables diff --git a/python/samples/03-workflows/README.md b/python/samples/03-workflows/README.md index 0c9f2c5df8..26eccd03e4 100644 --- a/python/samples/03-workflows/README.md +++ b/python/samples/03-workflows/README.md @@ -36,7 +36,7 @@ Once comfortable with these, explore the rest of the samples below. 
| -------------------------------------- | -------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------- | | Azure Chat Agents (Streaming) | [agents/azure_chat_agents_streaming.py](./agents/azure_chat_agents_streaming.py) | Add Azure Chat agents as edges and handle streaming events | | Azure AI Agents (Streaming) | [agents/azure_ai_agents_streaming.py](./agents/azure_ai_agents_streaming.py) | Add Azure AI agents as edges and handle streaming events | -| Azure AI Agents (Shared Thread) | [agents/azure_ai_agents_with_shared_thread.py](./agents/azure_ai_agents_with_shared_thread.py) | Share a common message thread between multiple Azure AI agents in a workflow | +| Azure AI Agents (Shared Thread) | [agents/azure_ai_agents_with_shared_thread.py](./agents/azure_ai_agents_with_shared_thread.py) | Share a common conversation session between multiple Azure AI agents in a workflow | | Custom Agent Executors | [agents/custom_agent_executors.py](./agents/custom_agent_executors.py) | Create executors to handle agent run methods | | Workflow as Agent (Reflection Pattern) | [agents/workflow_as_agent_reflection_pattern.py](./agents/workflow_as_agent_reflection_pattern.py) | Wrap a workflow so it can behave like an agent (reflection pattern) | | Workflow as Agent + HITL | [agents/workflow_as_agent_human_in_the_loop.py](./agents/workflow_as_agent_human_in_the_loop.py) | Extend workflow-as-agent with human-in-the-loop capability | diff --git a/python/samples/04-hosting/azure_functions/01_single_agent/README.md b/python/samples/04-hosting/azure_functions/01_single_agent/README.md index 38c6ce58f5..886f1156a0 100644 --- a/python/samples/04-hosting/azure_functions/01_single_agent/README.md +++ b/python/samples/04-hosting/azure_functions/01_single_agent/README.md @@ -7,7 +7,7 @@ This sample demonstrates how to use the Durable Extension for
Agent Framework to - Defining a simple agent with the Microsoft Agent Framework and wiring it into an Azure Functions app via the Durable Extension for Agent Framework. - Calling the agent through generated HTTP endpoints (`/api/agents/Joker/run`). -- Managing conversation state with thread identifiers, so multiple clients can +- Managing conversation state with session identifiers, so multiple clients can interact with the agent concurrently without sharing context. ## Prerequisites diff --git a/python/samples/04-hosting/azure_functions/02_multi_agent/README.md b/python/samples/04-hosting/azure_functions/02_multi_agent/README.md index e10b9d4d51..e133ca369c 100644 --- a/python/samples/04-hosting/azure_functions/02_multi_agent/README.md +++ b/python/samples/04-hosting/azure_functions/02_multi_agent/README.md @@ -6,7 +6,7 @@ This sample demonstrates how to use the Durable Extension for Agent Framework to - Using the Microsoft Agent Framework to define multiple AI agents with unique names and instructions. - Registering multiple agents with the Function app and running them using HTTP. -- Conversation management (via thread IDs) for isolated interactions per agent. +- Conversation management (via session IDs) for isolated interactions per agent. - Two different methods for registering agents: list-based initialization and incremental addition. ## Prerequisites diff --git a/python/samples/04-hosting/azure_functions/04_single_agent_orchestration_chaining/README.md b/python/samples/04-hosting/azure_functions/04_single_agent_orchestration_chaining/README.md index 13e8c08429..332c03d378 100644 --- a/python/samples/04-hosting/azure_functions/04_single_agent_orchestration_chaining/README.md +++ b/python/samples/04-hosting/azure_functions/04_single_agent_orchestration_chaining/README.md @@ -4,8 +4,8 @@ This sample shows how to chain two invocations of the same agent inside a Durabl preserving the conversation state between runs. 
## Key Concepts -- Deterministic orchestrations that make sequential agent calls on a shared thread -- Reusing an agent thread to carry conversation history across invocations +- Deterministic orchestrations that make sequential agent calls on a shared session +- Reusing an agent session to carry conversation history across invocations - HTTP endpoints for starting the orchestration and polling for status/output ## Prerequisites diff --git a/python/samples/04-hosting/durabletask/01_single_agent/README.md b/python/samples/04-hosting/durabletask/01_single_agent/README.md index 62a150e216..2b8ce83c13 100644 --- a/python/samples/04-hosting/durabletask/01_single_agent/README.md +++ b/python/samples/04-hosting/durabletask/01_single_agent/README.md @@ -6,7 +6,7 @@ This sample demonstrates how to create a worker-client setup that hosts a single - Using the Microsoft Agent Framework to define a simple AI agent with a name and instructions. - Registering durable agents with the worker and interacting with them via a client. -- Conversation management (via threads) for isolated interactions. +- Conversation management (via sessions) for isolated interactions. - Worker-client architecture for distributed agent execution. ## Environment Setup @@ -46,7 +46,7 @@ Using taskhub: default Using endpoint: http://localhost:8080 Getting reference to Joker agent... -Created conversation thread: a1b2c3d4-e5f6-7890-abcd-ef1234567890 +Created conversation session: a1b2c3d4-e5f6-7890-abcd-ef1234567890 User: Tell me a short joke about cloud computing. 
diff --git a/python/samples/04-hosting/durabletask/02_multi_agent/README.md b/python/samples/04-hosting/durabletask/02_multi_agent/README.md index aad51ba014..b2989579e8 100644 --- a/python/samples/04-hosting/durabletask/02_multi_agent/README.md +++ b/python/samples/04-hosting/durabletask/02_multi_agent/README.md @@ -6,7 +6,7 @@ This sample demonstrates how to host multiple AI agents with different tools in - Hosting multiple agents (WeatherAgent and MathAgent) in a single worker process. - Each agent with its own specialized tools and instructions. -- Interacting with different agents using separate conversation threads. +- Interacting with different agents using separate conversation sessions. - Worker-client architecture for multi-agent systems. ## Environment Setup @@ -49,7 +49,7 @@ Using endpoint: http://localhost:8080 Testing WeatherAgent ================================================================================ -Created weather conversation thread: +Created weather conversation session: User: What is the weather in Seattle? 
🔧 [TOOL CALLED] get_weather(location=Seattle) @@ -61,7 +61,7 @@ WeatherAgent: The current weather in Seattle is sunny with a temperature of 72° Testing MathAgent ================================================================================ -Created math conversation thread: +Created math conversation session: User: Calculate a 20% tip on a $50 bill 🔧 [TOOL CALLED] calculate_tip(bill_amount=50.0, tip_percentage=20.0) diff --git a/python/samples/04-hosting/durabletask/04_single_agent_orchestration_chaining/README.md b/python/samples/04-hosting/durabletask/04_single_agent_orchestration_chaining/README.md index 2c277423f1..4d015c28dc 100644 --- a/python/samples/04-hosting/durabletask/04_single_agent_orchestration_chaining/README.md +++ b/python/samples/04-hosting/durabletask/04_single_agent_orchestration_chaining/README.md @@ -6,7 +6,7 @@ This sample demonstrates how to chain multiple invocations of the same agent usi - Using durable orchestrations to coordinate sequential agent invocations. - Chaining agent calls where the output of one run becomes input to the next. -- Maintaining conversation context across sequential runs using a shared thread. +- Maintaining conversation context across sequential runs using a shared session. - Using `DurableAIAgentOrchestrationContext` to access agents within orchestrations. ## Environment Setup @@ -42,7 +42,7 @@ The orchestration will execute the writer agent twice sequentially: ``` [Orchestration] Starting single agent chaining... -[Orchestration] Created thread: abc-123 +[Orchestration] Created session: abc-123 [Orchestration] First agent run: Generating initial sentence... [Orchestration] Initial response: Every small step forward is progress toward mastery. [Orchestration] Second agent run: Refining the sentence... @@ -62,7 +62,7 @@ You can view the state of the orchestration in the Durable Task Scheduler dashbo 1. Open your browser and navigate to `http://localhost:8082` 2. 
In the dashboard, you can view: - The sequential execution of both agent runs - - The conversation thread shared between runs + - The conversation session shared between runs - Input and output at each step - Overall orchestration state and history diff --git a/python/samples/04-hosting/durabletask/05_multi_agent_orchestration_concurrency/README.md b/python/samples/04-hosting/durabletask/05_multi_agent_orchestration_concurrency/README.md index e2843f4798..e94602d822 100644 --- a/python/samples/04-hosting/durabletask/05_multi_agent_orchestration_concurrency/README.md +++ b/python/samples/04-hosting/durabletask/05_multi_agent_orchestration_concurrency/README.md @@ -7,7 +7,7 @@ This sample demonstrates how to host multiple agents and run them concurrently u - Running multiple specialized agents in parallel within an orchestration. - Using `OrchestrationAgentExecutor` to get `DurableAgentTask` objects for concurrent execution. - Aggregating results from multiple agents using `task.when_all()`. -- Creating separate conversation threads for independent agent contexts. +- Creating separate conversation sessions for independent agent contexts. ## Environment Setup @@ -64,7 +64,7 @@ You can view the state of the orchestration in the Durable Task Scheduler dashbo 1. Open your browser and navigate to `http://localhost:8082` 2. 
In the dashboard, you can view: - The concurrent execution of both agents (PhysicistAgent and ChemistAgent) - - Separate conversation threads for each agent + - Separate conversation sessions for each agent - Parallel task execution and completion timing - Aggregated results from both agents diff --git a/python/samples/04-hosting/durabletask/07_single_agent_orchestration_hitl/README.md b/python/samples/04-hosting/durabletask/07_single_agent_orchestration_hitl/README.md index 5e3f8eea27..9f47d51009 100644 --- a/python/samples/04-hosting/durabletask/07_single_agent_orchestration_hitl/README.md +++ b/python/samples/04-hosting/durabletask/07_single_agent_orchestration_hitl/README.md @@ -82,6 +82,6 @@ You can view the state of the WriterAgent and orchestration in the Durable Task 1. Open your browser and navigate to `http://localhost:8082` 2. In the dashboard, you can view: - Orchestration instance status and pending events - - WriterAgent entity state and conversation threads + - WriterAgent entity state and conversation sessions - Activity execution logs - External event history diff --git a/python/samples/autogen-migration/README.md b/python/samples/autogen-migration/README.md index 39e6afd582..2bfa229183 100644 --- a/python/samples/autogen-migration/README.md +++ b/python/samples/autogen-migration/README.md @@ -8,7 +8,7 @@ This gallery helps AutoGen developers move to the Microsoft Agent Framework (AF) - [01_basic_assistant_agent.py](single_agent/01_basic_assistant_agent.py) — Minimal AutoGen `AssistantAgent` and AF `Agent` comparison. - [02_assistant_agent_with_tool.py](single_agent/02_assistant_agent_with_tool.py) — Function tool integration in both SDKs. -- [03_assistant_agent_thread_and_stream.py](single_agent/03_assistant_agent_thread_and_stream.py) — Thread management and streaming responses. +- [03_assistant_agent_thread_and_stream.py](single_agent/03_assistant_agent_thread_and_stream.py) — Session management and streaming responses. 
- [04_agent_as_tool.py](single_agent/04_agent_as_tool.py) — Using agents as tools (hierarchical agent pattern) and streaming with tools. ### Multi-Agent Orchestration diff --git a/python/samples/getting_started/sessions/README.md b/python/samples/getting_started/sessions/README.md index 6910aab677..daee274c2c 100644 --- a/python/samples/getting_started/sessions/README.md +++ b/python/samples/getting_started/sessions/README.md @@ -15,9 +15,9 @@ Sessions and context providers are the core building blocks for agent memory in | File | Description | |------|-------------| -| [`suspend_resume_thread.py`](suspend_resume_thread.py) | Suspend and resume sessions via `to_dict()` / `from_dict()` — both service-managed (Azure AI) and in-memory (OpenAI). | -| [`custom_chat_message_store_thread.py`](custom_chat_message_store_thread.py) | Implement a custom `BaseHistoryProvider` with dict-based storage. Shows serialization/deserialization. | -| [`redis_chat_message_store_thread.py`](redis_chat_message_store_thread.py) | `RedisHistoryProvider` for persistent storage: basic usage, user sessions, persistence across restarts, serialization, and message trimming. | +| [`suspend_resume_session.py`](suspend_resume_session.py) | Suspend and resume sessions via `to_dict()` / `from_dict()` — both service-managed (Azure AI) and in-memory (OpenAI). | +| [`custom_history_provider.py`](custom_history_provider.py) | Implement a custom `BaseHistoryProvider` with dict-based storage. Shows serialization/deserialization. | +| [`redis_history_provider.py`](redis_history_provider.py) | `RedisHistoryProvider` for persistent storage: basic usage, user sessions, persistence across restarts, serialization, and message trimming. | ### Custom Context Providers @@ -37,7 +37,7 @@ Sessions and context providers are the core building blocks for agent memory in | File | Description | |------|-------------| | [`mem0/mem0_basic.py`](mem0/mem0_basic.py) | Basic Mem0 integration for user preference memory. 
| -| [`mem0/mem0_threads.py`](mem0/mem0_threads.py) | Thread scoping: global scope, per-operation scope, and multi-agent isolation. | +| [`mem0/mem0_sessions.py`](mem0/mem0_sessions.py) | Session scoping: global scope, per-operation scope, and multi-agent isolation. | | [`mem0/mem0_oss.py`](mem0/mem0_oss.py) | Mem0 Open Source (self-hosted) integration. | ### Redis @@ -46,7 +46,7 @@ Sessions and context providers are the core building blocks for agent memory in |------|-------------| | [`redis/redis_basics.py`](redis/redis_basics.py) | Standalone provider usage, full-text/hybrid search, preferences, and tool output memory. | | [`redis/redis_conversation.py`](redis/redis_conversation.py) | Conversation persistence across sessions. | -| [`redis/redis_threads.py`](redis/redis_threads.py) | Thread scoping: global, per-operation, and multi-agent isolation. | +| [`redis/redis_sessions.py`](redis/redis_sessions.py) | Session scoping: global, per-operation, and multi-agent isolation. | | [`redis/azure_redis_conversation.py`](redis/azure_redis_conversation.py) | Azure Managed Redis with Entra ID authentication. | ## Choosing a Provider @@ -100,4 +100,4 @@ class MyHistoryProvider(BaseHistoryProvider): ... # Persist to your storage ``` -See `custom_chat_message_store_thread.py` and `simple_context_provider.py` for complete examples. +See `custom_history_provider.py` and `simple_context_provider.py` for complete examples. 
diff --git a/python/samples/getting_started/sessions/custom_chat_message_store_thread.py b/python/samples/getting_started/sessions/custom_history_provider.py similarity index 100% rename from python/samples/getting_started/sessions/custom_chat_message_store_thread.py rename to python/samples/getting_started/sessions/custom_history_provider.py diff --git a/python/samples/getting_started/sessions/mem0/README.md b/python/samples/getting_started/sessions/mem0/README.md index 61d8bbd51f..667455a536 100644 --- a/python/samples/getting_started/sessions/mem0/README.md +++ b/python/samples/getting_started/sessions/mem0/README.md @@ -8,8 +8,8 @@ This folder contains examples demonstrating how to use the Mem0 context provider | File | Description | |------|-------------| -| [`mem0_basic.py`](mem0_basic.py) | Basic example of using Mem0 context provider to store and retrieve user preferences across different conversation threads. | -| [`mem0_threads.py`](mem0_threads.py) | Advanced example demonstrating different thread scoping strategies with Mem0. Covers global thread scope (memories shared across all operations), per-operation thread scope (memories isolated per thread), and multiple agents with different memory configurations for personal vs. work contexts. | +| [`mem0_basic.py`](mem0_basic.py) | Basic example of using Mem0 context provider to store and retrieve user preferences across different conversation sessions. | +| [`mem0_sessions.py`](mem0_sessions.py) | Advanced example demonstrating different session scoping strategies with Mem0. Covers global session scope (memories shared across all operations), per-operation session scope (memories isolated per session), and multiple agents with different memory configurations for personal vs. work contexts. | | [`mem0_oss.py`](mem0_oss.py) | Example of using the Mem0 Open Source self-hosted version as the context provider. Demonstrates setup and configuration for local deployment. 
| ## Prerequisites @@ -42,8 +42,8 @@ Set the following environment variables: The Mem0 context provider supports different scoping strategies: -- **Global Scope** (`scope_to_per_operation_thread_id=False`): Memories are shared across all conversation threads -- **Thread Scope** (`scope_to_per_operation_thread_id=True`): Memories are isolated per conversation thread +- **Global Scope** (`scope_to_per_operation_thread_id=False`): Memories are shared across all conversation sessions +- **Session Scope** (`scope_to_per_operation_thread_id=True`): Memories are isolated per conversation session ### Memory Association @@ -51,5 +51,5 @@ Mem0 records can be associated with different identifiers: - `user_id`: Associate memories with a specific user - `agent_id`: Associate memories with a specific agent -- `thread_id`: Associate memories with a specific conversation thread +- `thread_id`: Associate memories with a specific conversation session - `application_id`: Associate memories with an application context diff --git a/python/samples/getting_started/sessions/mem0/mem0_threads.py b/python/samples/getting_started/sessions/mem0/mem0_sessions.py similarity index 100% rename from python/samples/getting_started/sessions/mem0/mem0_threads.py rename to python/samples/getting_started/sessions/mem0/mem0_sessions.py diff --git a/python/samples/getting_started/sessions/redis/README.md b/python/samples/getting_started/sessions/redis/README.md index dec2c77485..03c41295f3 100644 --- a/python/samples/getting_started/sessions/redis/README.md +++ b/python/samples/getting_started/sessions/redis/README.md @@ -1,6 +1,6 @@ # Redis Context Provider Examples -The Redis context provider enables persistent, searchable memory for your agents using Redis (RediSearch). It supports full‑text search and optional hybrid search with vector embeddings, letting agents remember and retrieve user context across sessions and threads. 
+The Redis context provider enables persistent, searchable memory for your agents using Redis (RediSearch). It supports full‑text search and optional hybrid search with vector embeddings, letting agents remember and retrieve user context across sessions. This folder contains an example demonstrating how to use the Redis context provider with the Agent Framework. @@ -9,9 +9,9 @@ This folder contains an example demonstrating how to use the Redis context provi | File | Description | |------|-------------| | [`azure_redis_conversation.py`](azure_redis_conversation.py) | Demonstrates conversation persistence with RedisHistoryProvider and Azure Redis with Azure AD (Entra ID) authentication using credential provider. | -| [`redis_basics.py`](redis_basics.py) | Shows standalone provider usage and agent integration. Demonstrates writing messages to Redis, retrieving context via full‑text or hybrid vector search, and persisting preferences across threads. Also includes a simple tool example whose outputs are remembered. | +| [`redis_basics.py`](redis_basics.py) | Shows standalone provider usage and agent integration. Demonstrates writing messages to Redis, retrieving context via full‑text or hybrid vector search, and persisting preferences across sessions. Also includes a simple tool example whose outputs are remembered. | | [`redis_conversation.py`](redis_conversation.py) | Simple example showing conversation persistence with RedisContextProvider using traditional connection string authentication. | -| [`redis_threads.py`](redis_threads.py) | Demonstrates thread scoping. Includes: (1) global thread scope with a fixed `thread_id` shared across operations; (2) per‑operation thread scope where `scope_to_per_operation_thread_id=True` binds memory to a single thread for the provider's lifetime; and (3) multiple agents with isolated memory via different `agent_id` values. | +| [`redis_sessions.py`](redis_sessions.py) | Demonstrates session scoping. 
Includes: (1) global session scope with a fixed `thread_id` shared across operations; (2) per‑operation session scope where `scope_to_per_operation_thread_id=True` binds memory to a single session for the provider's lifetime; and (3) multiple agents with isolated memory via different `agent_id` values. | ## Prerequisites @@ -59,7 +59,7 @@ The provider supports both full‑text only and hybrid vector search: - Set `vectorizer_choice` to `"openai"` or `"hf"` to enable embeddings and hybrid search. - When using a vectorizer, also set `vector_field_name` (e.g., `"vector"`). - Partition fields for scoping memory: `application_id`, `agent_id`, `user_id`, `thread_id`. -- Thread scoping: `scope_to_per_operation_thread_id=True` isolates memory per operation thread. +- Session scoping: `scope_to_per_operation_thread_id=True` isolates memory per operation session. - Index management: `index_name`, `overwrite_redis_index`, `drop_redis_index`. ## What the example does @@ -95,7 +95,7 @@ You should see the agent responses and, when using embeddings, context retrieved ### Memory scoping - Global scope: set `application_id`, `agent_id`, `user_id`, or `thread_id` on the provider to filter memory. -- Per‑operation thread scope: set `scope_to_per_operation_thread_id=True` to isolate memory to the current thread created by the framework. +- Per‑operation session scope: set `scope_to_per_operation_thread_id=True` to isolate memory to the current session created by the framework. 
### Hybrid vector search (optional) diff --git a/python/samples/getting_started/sessions/redis/redis_threads.py b/python/samples/getting_started/sessions/redis/redis_sessions.py similarity index 100% rename from python/samples/getting_started/sessions/redis/redis_threads.py rename to python/samples/getting_started/sessions/redis/redis_sessions.py diff --git a/python/samples/getting_started/sessions/redis_chat_message_store_thread.py b/python/samples/getting_started/sessions/redis_history_provider.py similarity index 100% rename from python/samples/getting_started/sessions/redis_chat_message_store_thread.py rename to python/samples/getting_started/sessions/redis_history_provider.py diff --git a/python/samples/getting_started/sessions/suspend_resume_thread.py b/python/samples/getting_started/sessions/suspend_resume_session.py similarity index 100% rename from python/samples/getting_started/sessions/suspend_resume_thread.py rename to python/samples/getting_started/sessions/suspend_resume_session.py diff --git a/python/samples/semantic-kernel-migration/README.md b/python/samples/semantic-kernel-migration/README.md index 6e6a135a0f..3a298fcf3d 100644 --- a/python/samples/semantic-kernel-migration/README.md +++ b/python/samples/semantic-kernel-migration/README.md @@ -9,12 +9,12 @@ This gallery helps Semantic Kernel (SK) developers move to the Microsoft Agent F ### Chat completion parity - [01_basic_chat_completion.py](chat_completion/01_basic_chat_completion.py) — Minimal SK `ChatCompletionAgent` and AF `Agent` conversation. - [02_chat_completion_with_tool.py](chat_completion/02_chat_completion_with_tool.py) — Adds a simple tool/function call in both SDKs. -- [03_chat_completion_thread_and_stream.py](chat_completion/03_chat_completion_thread_and_stream.py) — Demonstrates thread reuse and streaming prompts. +- [03_chat_completion_thread_and_stream.py](chat_completion/03_chat_completion_thread_and_stream.py) — Demonstrates session reuse and streaming prompts. 
### Azure AI agent parity - [01_basic_azure_ai_agent.py](azure_ai_agent/01_basic_azure_ai_agent.py) — Create and run an Azure AI agent end to end. - [02_azure_ai_agent_with_code_interpreter.py](azure_ai_agent/02_azure_ai_agent_with_code_interpreter.py) — Enable hosted code interpreter/tool execution. -- [03_azure_ai_agent_threads_and_followups.py](azure_ai_agent/03_azure_ai_agent_threads_and_followups.py) — Persist threads and follow-ups across invocations. +- [03_azure_ai_agent_threads_and_followups.py](azure_ai_agent/03_azure_ai_agent_threads_and_followups.py) — Persist sessions and follow-ups across invocations. ### OpenAI Assistants API parity - [01_basic_openai_assistant.py](openai_assistant/01_basic_openai_assistant.py) — Baseline assistant comparison. @@ -70,6 +70,6 @@ Swap the script path for any other workflow or process sample. Deactivate the sa ## Tips for Migration - Keep the original SK sample open while iterating on the AF equivalent; the code is intentionally formatted so you can copy/paste across SDKs. -- Threads/conversation state are explicit in AF. When porting SK code that relies on implicit thread reuse, call `agent.create_session()` and pass it into each `run` call. +- Sessions/conversation state are explicit in AF. When porting SK code that relies on implicit session reuse, call `agent.create_session()` and pass it into each `run` call. - Tools map cleanly: SK `@kernel_function` plugins translate to AF `@tool` callables. Hosted tools (code interpreter, web search, MCP) are available only in AF—introduce them once parity is achieved. - For multi-agent orchestration, AF workflows expose checkpoints and resume capabilities that SK Process/Team abstractions do not. Use the workflow samples as a blueprint when modernizing complex agent graphs. 
From b75f88295b7590526d9467992271af0d5e54962b Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Feb 2026 16:27:57 +0100 Subject: [PATCH 18/28] Fix broken Redis README link to renamed sample --- python/packages/redis/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/packages/redis/README.md b/python/packages/redis/README.md index be36dc8c28..3517f460de 100644 --- a/python/packages/redis/README.md +++ b/python/packages/redis/README.md @@ -30,10 +30,10 @@ The `RedisChatMessageStore` provides persistent conversation storage using Redis #### Basic Usage Examples -See the complete [Redis chat message store examples](../../samples/02-agents/conversations/redis_chat_message_store_thread.py) including: +See the complete [Redis history provider examples](../../samples/02-agents/conversations/redis_history_provider.py) including: - User session management - Conversation persistence across restarts -- Thread serialization and deserialization +- Session serialization and deserialization - Automatic message trimming - Error handling patterns From 92744cc77c5acc709e69f06da40327da64736a80 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Feb 2026 16:56:27 +0100 Subject: [PATCH 19/28] Fix Mem0 OSS client search: pass scoping params as direct kwargs AsyncMemory (OSS) expects user_id/agent_id/run_id as direct kwargs, while AsyncMemoryClient (Platform) expects them in a filters dict. Adds tests for both client types. Port of fix from #3844 to new Mem0ContextProvider. 
--- .../agent_framework_mem0/_context_provider.py | 11 +++- .../mem0/tests/test_mem0_context_provider.py | 55 +++++++++++++++++++ 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/python/packages/mem0/agent_framework_mem0/_context_provider.py b/python/packages/mem0/agent_framework_mem0/_context_provider.py index c2d10d42cb..ce3140d4ef 100644 --- a/python/packages/mem0/agent_framework_mem0/_context_provider.py +++ b/python/packages/mem0/agent_framework_mem0/_context_provider.py @@ -108,9 +108,16 @@ async def before_run( filters = self._build_filters(session_id=context.session_id) + # AsyncMemory (OSS) expects user_id/agent_id/run_id as direct kwargs + # AsyncMemoryClient (Platform) expects them in a filters dict + search_kwargs: dict[str, Any] = {"query": input_text} + if isinstance(self.mem0_client, AsyncMemory): + search_kwargs.update(filters) + else: + search_kwargs["filters"] = filters + search_response: _MemorySearchResponse_v1_1 | _MemorySearchResponse_v2 = await self.mem0_client.search( # type: ignore[misc] - query=input_text, - filters=filters, + **search_kwargs, ) if isinstance(search_response, list): diff --git a/python/packages/mem0/tests/test_mem0_context_provider.py b/python/packages/mem0/tests/test_mem0_context_provider.py index 96a70c2beb..c13ac58dd4 100644 --- a/python/packages/mem0/tests/test_mem0_context_provider.py +++ b/python/packages/mem0/tests/test_mem0_context_provider.py @@ -26,6 +26,17 @@ def mock_mem0_client() -> AsyncMock: return mock_client +@pytest.fixture +def mock_oss_mem0_client() -> AsyncMock: + """Create a mock Mem0 OSS AsyncMemory client.""" + from mem0 import AsyncMemory + + mock_client = AsyncMock(spec=AsyncMemory) + mock_client.add = AsyncMock() + mock_client.search = AsyncMock() + return mock_client + + # -- Initialization tests ------------------------------------------------------ @@ -157,6 +168,50 @@ async def test_search_query_combines_input_messages(self, mock_mem0_client: Asyn call_kwargs = 
mock_mem0_client.search.call_args.kwargs assert call_kwargs["query"] == "Hello\nWorld" + async def test_oss_client_passes_direct_kwargs(self, mock_oss_mem0_client: AsyncMock) -> None: + """OSS AsyncMemory client should receive user_id as direct kwarg, not in filters.""" + mock_oss_mem0_client.search.return_value = [{"memory": "User likes Python"}] + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_oss_mem0_client, user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[Message(role="user", text="Hello")], session_id="s1") + + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + call_kwargs = mock_oss_mem0_client.search.call_args.kwargs + assert call_kwargs["query"] == "Hello" + assert call_kwargs["user_id"] == "u1" + assert "filters" not in call_kwargs + + async def test_oss_client_all_scoping_params(self, mock_oss_mem0_client: AsyncMock) -> None: + """OSS client with all scoping parameters passes them as direct kwargs.""" + mock_oss_mem0_client.search.return_value = [] + provider = Mem0ContextProvider( + source_id="mem0", mem0_client=mock_oss_mem0_client, user_id="u1", agent_id="a1", application_id="app1" + ) + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[Message(role="user", text="Hello")], session_id="s1") + + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + call_kwargs = mock_oss_mem0_client.search.call_args.kwargs + assert call_kwargs["user_id"] == "u1" + assert call_kwargs["agent_id"] == "a1" + assert "filters" not in call_kwargs + + async def test_platform_client_passes_filters_dict(self, mock_mem0_client: AsyncMock) -> None: + """Platform AsyncMemoryClient should receive scoping params in a filters dict.""" + mock_mem0_client.search.return_value = [] + provider = Mem0ContextProvider(source_id="mem0", 
mem0_client=mock_mem0_client, user_id="u1") + session = AgentSession(session_id="test-session") + ctx = SessionContext(input_messages=[Message(role="user", text="Hello")], session_id="s1") + + await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + + call_kwargs = mock_mem0_client.search.call_args.kwargs + assert call_kwargs["query"] == "Hello" + assert "filters" in call_kwargs + assert call_kwargs["filters"]["user_id"] == "u1" + # -- after_run tests ----------------------------------------------------------- From 73eb64e84489d40389c5513c4713b01661e60853 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Feb 2026 22:52:55 +0100 Subject: [PATCH 20/28] Fix rebase issues: restore missing _conversation_state.py and checkpoint decode logic - Add back _conversation_state.py (encode/decode_chat_messages) lost in rebase - Fix on_checkpoint_restore to decode cache/conversation with decode_chat_messages - Fix on_checkpoint_restore to use decode_checkpoint_value for pending requests - Add tests/workflow/__init__.py for relative import support - Fix test_agent_executor checkpoint selection (checkpoints[1] not superstep) --- .../_workflows/_agent_executor.py | 26 +++++-- .../_workflows/_conversation_state.py | 75 +++++++++++++++++++ .../packages/core/tests/workflow/__init__.py | 0 .../tests/workflow/test_agent_executor.py | 5 ++ 4 files changed, 101 insertions(+), 5 deletions(-) create mode 100644 python/packages/core/agent_framework/_workflows/_conversation_state.py create mode 100644 python/packages/core/tests/workflow/__init__.py diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 6252bf8ffe..74dbca3540 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -14,7 +14,7 @@ from .._sessions import 
AgentSession from .._types import AgentResponse, AgentResponseUpdate, Message from ._agent_utils import resolve_agent_id -from ._checkpoint_encoding import encode_checkpoint_value +from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value from ._const import WORKFLOW_RUN_KWARGS_KEY from ._conversation_state import encode_chat_messages from ._executor import Executor, handler @@ -245,11 +245,27 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: Args: state: Checkpoint data dict """ + from ._conversation_state import decode_chat_messages + cache_payload = state.get("cache") - self._cache = cache_payload or [] + if cache_payload: + try: + self._cache = decode_chat_messages(cache_payload) + except Exception as exc: + logger.warning("Failed to restore cache: %s", exc) + self._cache = [] + else: + self._cache = [] full_conversation_payload = state.get("full_conversation") - self._full_conversation = full_conversation_payload or [] + if full_conversation_payload: + try: + self._full_conversation = decode_chat_messages(full_conversation_payload) + except Exception as exc: + logger.warning("Failed to restore full conversation: %s", exc) + self._full_conversation = [] + else: + self._full_conversation = [] session_payload = state.get("agent_session") if session_payload: @@ -263,11 +279,11 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: pending_requests_payload = state.get("pending_agent_requests") if pending_requests_payload: - self._pending_agent_requests = pending_requests_payload + self._pending_agent_requests = decode_checkpoint_value(pending_requests_payload) pending_responses_payload = state.get("pending_responses_to_agent") if pending_responses_payload: - self._pending_responses_to_agent = pending_responses_payload + self._pending_responses_to_agent = decode_checkpoint_value(pending_responses_payload) def reset(self) -> None: """Reset the internal cache of the executor.""" diff --git 
a/python/packages/core/agent_framework/_workflows/_conversation_state.py b/python/packages/core/agent_framework/_workflows/_conversation_state.py new file mode 100644 index 0000000000..95945998df --- /dev/null +++ b/python/packages/core/agent_framework/_workflows/_conversation_state.py @@ -0,0 +1,75 @@ +# Copyright (c) Microsoft. All rights reserved. + +from collections.abc import Iterable +from typing import Any, cast + +from agent_framework import Message + +from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value + +"""Utilities for serializing and deserializing chat conversations for persistence. + +These helpers convert rich `Message` instances to checkpoint-friendly payloads +using the same encoding primitives as the workflow runner. This preserves +`additional_properties` and other metadata without relying on unsafe mechanisms +such as pickling. +""" + + +def encode_chat_messages(messages: Iterable[Message]) -> list[dict[str, Any]]: + """Serialize chat messages into checkpoint-safe payloads.""" + encoded: list[dict[str, Any]] = [] + for message in messages: + encoded.append({ + "role": encode_checkpoint_value(message.role), + "contents": [encode_checkpoint_value(content) for content in message.contents], + "author_name": message.author_name, + "message_id": message.message_id, + "additional_properties": { + key: encode_checkpoint_value(value) for key, value in message.additional_properties.items() + }, + }) + return encoded + + +def decode_chat_messages(payload: Iterable[dict[str, Any]]) -> list[Message]: + """Restore chat messages from checkpoint-safe payloads.""" + restored: list[Message] = [] + for item in payload: + if not isinstance(item, dict): + continue + + role_value = decode_checkpoint_value(item.get("role")) + if isinstance(role_value, str): + role = role_value + elif isinstance(role_value, dict) and "value" in role_value: + # Handle legacy serialization format + role = role_value["value"] + else: + role = "assistant" + + 
contents_field = item.get("contents", []) + contents: list[Any] = [] + if isinstance(contents_field, list): + contents_iter: list[Any] = contents_field # type: ignore[assignment] + for entry in contents_iter: + decoded_entry: Any = decode_checkpoint_value(entry) + contents.append(decoded_entry) + + additional_field = item.get("additional_properties", {}) + additional: dict[str, Any] = {} + if isinstance(additional_field, dict): + additional_dict = cast(dict[str, Any], additional_field) + for key, value in additional_dict.items(): + additional[key] = decode_checkpoint_value(value) + + restored.append( + Message( # type: ignore[call-overload] + role=role, + contents=contents, + author_name=item.get("author_name"), + message_id=item.get("message_id"), + additional_properties=additional, + ) + ) + return restored diff --git a/python/packages/core/tests/workflow/__init__.py b/python/packages/core/tests/workflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index 4dadbdfb11..07b15d5bf1 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -89,6 +89,11 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: "and the second one is after the agent execution." 
) + # Get the second checkpoint which should contain the state after processing + # the first message by the start executor in the sequential workflow + checkpoints.sort(key=lambda cp: cp.timestamp) + restore_checkpoint = checkpoints[1] + # Verify checkpoint contains executor state with both cache and session assert "_executor_state" in restore_checkpoint.state executor_states = restore_checkpoint.state["_executor_state"] From c9834487eff85dc182a860854d9c0df1ebc3edc2 Mon Sep 17 00:00:00 2001 From: Eduard van Valkenburg Date: Thu, 12 Feb 2026 13:41:24 +0100 Subject: [PATCH 21/28] Add STORES_BY_DEFAULT ClassVar to skip redundant InMemoryHistoryProvider injection Chat clients that store history server-side by default (OpenAI Responses API, Azure AI Agent) now declare STORES_BY_DEFAULT = True. The agent checks this during auto-injection and skips InMemoryHistoryProvider unless the user explicitly sets store=False. --- .../_context_provider.py | 109 ++++++------------ .../tests/test_aisearch_context_provider.py | 10 +- .../agent_framework_azure_ai/_chat_client.py | 1 + .../packages/core/agent_framework/_agents.py | 11 +- .../packages/core/agent_framework/_clients.py | 10 +- .../core/agent_framework/_settings.py | 60 ++++++---- .../openai/_responses_client.py | 4 +- .../packages/core/tests/core/test_agents.py | 50 ++++++++ .../packages/core/tests/core/test_settings.py | 92 +++++++++++++++ .../tests/workflow/test_workflow_kwargs.py | 2 +- .../agent_framework_durabletask/_shim.py | 6 +- .../_handoff.py | 6 +- .../orchestrations/tests/test_group_chat.py | 4 +- .../azure_ai_agents_with_shared_thread.py | 1 + 14 files changed, 252 insertions(+), 114 deletions(-) diff --git a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py index 0b0a24768d..9edf73fdc3 100644 --- a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py +++ 
b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py @@ -10,13 +10,12 @@ import sys from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any, ClassVar, Literal +from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Message from agent_framework._logging import get_logger -from agent_framework._pydantic import AFBaseSettings from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext -from agent_framework._settings import load_settings +from agent_framework._settings import SecretString, load_settings from agent_framework.exceptions import ServiceInitializationError from azure.core.credentials import AzureKeyCredential from azure.core.credentials_async import AsyncTokenCredential @@ -42,54 +41,6 @@ VectorizableTextQuery, VectorizedQuery, ) -from pydantic import SecretStr - - -class AzureAISearchSettings(AFBaseSettings): - """Settings for Azure AI Search Context Provider with auto-loading from environment. - - The settings are first loaded from environment variables with the prefix 'AZURE_SEARCH_'. - If the environment variables are not found, the settings can be loaded from a .env file. - - Keyword Args: - endpoint: Azure AI Search endpoint URL. - Can be set via environment variable AZURE_SEARCH_ENDPOINT. - index_name: Name of the search index. - Can be set via environment variable AZURE_SEARCH_INDEX_NAME. - knowledge_base_name: Name of an existing Knowledge Base (for agentic mode). - Can be set via environment variable AZURE_SEARCH_KNOWLEDGE_BASE_NAME. - api_key: API key for authentication (optional, use managed identity if not provided). - Can be set via environment variable AZURE_SEARCH_API_KEY. - env_file_path: If provided, the .env settings are read from this file path location. - env_file_encoding: The encoding of the .env file, defaults to 'utf-8'. - - Examples: - .. 
code-block:: python - - from agent_framework_aisearch import AzureAISearchSettings - - # Using environment variables - # Set AZURE_SEARCH_ENDPOINT=https://mysearch.search.windows.net - # Set AZURE_SEARCH_INDEX_NAME=my-index - settings = AzureAISearchSettings() - - # Or passing parameters directly - settings = AzureAISearchSettings( - endpoint="https://mysearch.search.windows.net", - index_name="my-index", - ) - - # Or loading from a .env file - settings = AzureAISearchSettings(env_file_path="path/to/.env") - """ - - env_prefix: ClassVar[str] = "AZURE_SEARCH_" - - endpoint: str | None = None - index_name: str | None = None - knowledge_base_name: str | None = None - api_key: SecretStr | None = None - if TYPE_CHECKING: from agent_framework._agents import SupportsAgentRun @@ -157,6 +108,29 @@ class AzureAISearchSettings(AFBaseSettings): _DEFAULT_AGENTIC_MESSAGE_HISTORY_COUNT = 10 +class AzureAISearchSettings(TypedDict, total=False): + """Settings for Azure AI Search Context Provider with auto-loading from environment. + + The settings are first loaded from environment variables with the prefix 'AZURE_SEARCH_'. + If the environment variables are not found, the settings can be loaded from a .env file. + + Keys: + endpoint: Azure AI Search endpoint URL. + Can be set via environment variable AZURE_SEARCH_ENDPOINT. + index_name: Name of the search index. + Can be set via environment variable AZURE_SEARCH_INDEX_NAME. + knowledge_base_name: Name of an existing Knowledge Base (for agentic mode). + Can be set via environment variable AZURE_SEARCH_KNOWLEDGE_BASE_NAME. + api_key: API key for authentication (optional, use managed identity if not provided). + Can be set via environment variable AZURE_SEARCH_API_KEY. + """ + + endpoint: str | None + index_name: str | None + knowledge_base_name: str | None + api_key: SecretString | None + + class AzureAISearchContextProvider(BaseContextProvider): """Azure AI Search context provider using the new BaseContextProvider hooks pattern. 
@@ -220,10 +194,18 @@ def __init__( """ super().__init__(source_id) + # Determine which fields are required based on mode + required: list[str | tuple[str, ...]] = ["endpoint"] + if mode == "semantic": + required.append("index_name") + elif mode == "agentic": + required.append(("index_name", "knowledge_base_name")) + # Load settings from environment/file settings = load_settings( AzureAISearchSettings, env_prefix="AZURE_SEARCH_", + required_fields=required, endpoint=endpoint, index_name=index_name, knowledge_base_name=knowledge_base_name, @@ -232,32 +214,11 @@ def __init__( env_file_encoding=env_file_encoding, ) - if not settings.get("endpoint"): + if mode == "agentic" and settings.get("index_name") and not model_deployment_name: raise ServiceInitializationError( - "Azure AI Search endpoint is required. Set via 'endpoint' parameter " - "or 'AZURE_SEARCH_ENDPOINT' environment variable." + "model_deployment_name is required for agentic mode when creating Knowledge Base from index." ) - if mode == "semantic": - if not settings.get("index_name"): - raise ServiceInitializationError( - "Azure AI Search index name is required for semantic mode. " - "Set via 'index_name' parameter or 'AZURE_SEARCH_INDEX_NAME' environment variable." - ) - elif mode == "agentic": - if settings.get("index_name") and settings.get("knowledge_base_name"): - raise ServiceInitializationError( - "For agentic mode, provide either 'index_name' OR 'knowledge_base_name', not both." - ) - if not settings.get("index_name") and not settings.get("knowledge_base_name"): - raise ServiceInitializationError( - "For agentic mode, provide either 'index_name' or 'knowledge_base_name'." - ) - if settings.get("index_name") and not model_deployment_name: - raise ServiceInitializationError( - "model_deployment_name is required for agentic mode when creating Knowledge Base from index." 
- ) - resolved_credential: AzureKeyCredential | AsyncTokenCredential if credential: resolved_credential = credential diff --git a/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py b/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py index 8c18617e6e..96ed975b54 100644 --- a/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py +++ b/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py @@ -7,7 +7,7 @@ import pytest from agent_framework import Message from agent_framework._sessions import AgentSession, SessionContext -from agent_framework.exceptions import ServiceInitializationError +from agent_framework.exceptions import ServiceInitializationError, SettingNotFoundError from agent_framework_azure_ai_search._context_provider import AzureAISearchContextProvider @@ -88,7 +88,7 @@ def test_source_id_set(self) -> None: assert provider.source_id == "my-source" def test_missing_endpoint_raises(self) -> None: - with patch.dict(os.environ, {}, clear=True), pytest.raises(ServiceInitializationError, match="endpoint"): + with patch.dict(os.environ, {}, clear=True), pytest.raises(SettingNotFoundError, match="endpoint"): AzureAISearchContextProvider( source_id="s", endpoint=None, @@ -97,7 +97,7 @@ def test_missing_endpoint_raises(self) -> None: ) def test_missing_index_name_semantic_raises(self) -> None: - with pytest.raises(ServiceInitializationError, match="index name"): + with pytest.raises(SettingNotFoundError, match="index_name"): AzureAISearchContextProvider( source_id="s", endpoint="https://test.search.windows.net", @@ -124,7 +124,7 @@ class TestInitAgenticValidation: """Initialization validation tests for agentic mode.""" def test_both_index_and_kb_raises(self) -> None: - with pytest.raises(ServiceInitializationError, match="not both"): + with pytest.raises(SettingNotFoundError, match="multiple were set"): AzureAISearchContextProvider( source_id="s", endpoint="https://test.search.windows.net", 
@@ -137,7 +137,7 @@ def test_both_index_and_kb_raises(self) -> None: ) def test_neither_index_nor_kb_raises(self) -> None: - with pytest.raises(ServiceInitializationError, match="provide either"): + with pytest.raises(SettingNotFoundError, match="none was set"): AzureAISearchContextProvider( source_id="s", endpoint="https://test.search.windows.net", diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py index 186cee6d1f..22d77c76b8 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py @@ -210,6 +210,7 @@ class AzureAIAgentClient( """Azure AI Agent Chat client with middleware, telemetry, and function invocation support.""" OTEL_PROVIDER_NAME: ClassVar[str] = "azure.ai" # type: ignore[reportIncompatibleVariableOverride, misc] + STORES_BY_DEFAULT: ClassVar[bool] = True # type: ignore[reportIncompatibleVariableOverride, misc] # region Hosted Tool Factory Methods diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index a84445678c..56a616471c 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -67,14 +67,9 @@ if TYPE_CHECKING: from ._types import ChatOptions - -ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel | None, default=None, covariant=True) -ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel) - - logger = get_logger("agent_framework") -ThreadTypeT = TypeVar("ThreadTypeT", bound="AgentSession") +ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel) OptionsCoT = TypeVar( "OptionsCoT", bound=TypedDict, # type: ignore[valid-type] @@ -978,6 +973,10 @@ async def _prepare_run_context( and not session.service_session_id and not opts.get("conversation_id") and not opts.get("store") + and not ( + getattr(self.client, 
"STORES_BY_DEFAULT", False) + and opts.get("store") is not False + ) ): self.context_providers.append(InMemoryHistoryProvider("memory")) diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index f5abb1d999..9abdbb5697 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -262,7 +262,15 @@ async def _stream(): OTEL_PROVIDER_NAME: ClassVar[str] = "unknown" DEFAULT_EXCLUDE: ClassVar[set[str]] = {"additional_properties"} - # This is used for OTel setup, should be overridden in subclasses + STORES_BY_DEFAULT: ClassVar[bool] = False + """Whether this client stores conversation history server-side by default. + + Clients that use server-side storage (e.g., OpenAI Responses API with ``store=True`` + as default, Azure AI Agent threads) should override this to ``True``. + When ``True``, the agent skips auto-injecting ``InMemoryHistoryProvider`` unless the + user explicitly sets ``store=False``. 
+ """ + # OTEL_PROVIDER_NAME is used for OTel setup, should be overridden in subclasses def __init__( self, diff --git a/python/packages/core/agent_framework/_settings.py b/python/packages/core/agent_framework/_settings.py index d378688d55..57919e7124 100644 --- a/python/packages/core/agent_framework/_settings.py +++ b/python/packages/core/agent_framework/_settings.py @@ -12,14 +12,17 @@ class MySettings(TypedDict, total=False): api_key: str | None # optional — resolves to None if not set model_id: str | None # optional by default + source_a: str | None + source_b: str | None - # Make model_id required at call time: + # Make model_id required; require exactly one of source_a / source_b: settings = load_settings( MySettings, env_prefix="MY_APP_", - required_fields=["model_id"], + required_fields=["model_id", ("source_a", "source_b")], model_id="gpt-4", + source_a="value", ) settings["api_key"] # type-checked dict access settings["model_id"] # str | None per type, but guaranteed not None at runtime @@ -167,7 +170,7 @@ def load_settings( env_prefix: str = "", env_file_path: str | None = None, env_file_encoding: str | None = None, - required_fields: Sequence[str] | None = None, + required_fields: Sequence[str | tuple[str, ...]] | None = None, **overrides: Any, ) -> SettingsT: """Load settings from environment variables, a ``.env`` file, and explicit overrides. @@ -181,18 +184,19 @@ def load_settings( 4. Default values — fields with class-level defaults on the TypedDict, or ``None`` for optional fields. - Fields listed in *required_fields* are validated after resolution. If any - required field resolves to ``None``, a ``SettingNotFoundError`` is raised. - This allows callers to decide which fields are required based on runtime - context (e.g. ``endpoint`` is only required when no pre-built client is - provided). + Entries in *required_fields* are validated after resolution: + + - A **string** entry means the field must resolve to a non-``None`` value. 
+ - A **tuple** entry means exactly one field in the group must be non-``None`` + (mutually exclusive). Args: settings_type: A ``TypedDict`` class describing the settings schema. env_prefix: Prefix for environment variable lookup (e.g. ``"OPENAI_"``). env_file_path: Path to ``.env`` file. Defaults to ``".env"`` when omitted. env_file_encoding: Encoding for reading the ``.env`` file. Defaults to ``"utf-8"``. - required_fields: Field names that must resolve to a non-``None`` value. + required_fields: Field names (``str``) that must resolve to a non-``None`` + value, or tuples of field names where exactly one must be set. **overrides: Field values. ``None`` values are ignored so that callers can forward optional parameters without masking env-var / default resolution. @@ -200,7 +204,8 @@ def load_settings( A populated dict matching *settings_type*. Raises: - SettingNotFoundError: If a required field could not be resolved from any source. + SettingNotFoundError: If a required field could not be resolved from any + source, or if a mutually exclusive constraint is violated. ServiceInitializationError: If an override value has an incompatible type. """ encoding = env_file_encoding or "utf-8" @@ -215,7 +220,6 @@ def load_settings( # Get field type hints from the TypedDict hints = get_type_hints(settings_type) - required: set[str] = set(required_fields) if required_fields else set() result: dict[str, Any] = {} for field_name, field_type in hints.items(): @@ -249,14 +253,30 @@ def load_settings( result[field_name] = None # Validate required fields after all resolution - if required: - for field_name in required: - if result.get(field_name) is None: - env_var_name = f"{env_prefix}{field_name.upper()}" - raise SettingNotFoundError( - f"Required setting '{field_name}' was not provided. " - f"Set it via the '{field_name}' parameter or the " - f"'{env_var_name}' environment variable." 
- ) + if required_fields: + for entry in required_fields: + if isinstance(entry, str): + # Single required field + if result.get(entry) is None: + env_var_name = f"{env_prefix}{entry.upper()}" + raise SettingNotFoundError( + f"Required setting '{entry}' was not provided. " + f"Set it via the '{entry}' parameter or the " + f"'{env_var_name}' environment variable." + ) + else: + # Mutually exclusive group — exactly one must be set + set_fields = [f for f in entry if result.get(f) is not None] + if len(set_fields) == 0: + names = ", ".join(f"'{f}'" for f in entry) + raise SettingNotFoundError( + f"Exactly one of {names} must be provided, but none was set." + ) + if len(set_fields) > 1: + all_names = ", ".join(f"'{f}'" for f in entry) + set_names = ", ".join(f"'{f}'" for f in set_fields) + raise SettingNotFoundError( + f"Only one of {all_names} may be provided, but multiple were set: {set_names}." + ) return result # type: ignore[return-value] diff --git a/python/packages/core/agent_framework/openai/_responses_client.py b/python/packages/core/agent_framework/openai/_responses_client.py index 5f2b637f1c..5ab414dc85 100644 --- a/python/packages/core/agent_framework/openai/_responses_client.py +++ b/python/packages/core/agent_framework/openai/_responses_client.py @@ -13,7 +13,7 @@ ) from datetime import datetime, timezone from itertools import chain -from typing import TYPE_CHECKING, Any, Generic, Literal, NoReturn, TypedDict, cast +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, NoReturn, TypedDict, cast from openai import AsyncOpenAI, BadRequestError from openai.types.responses.file_search_tool_param import FileSearchToolParam @@ -238,6 +238,8 @@ class RawOpenAIResponsesClient( # type: ignore[misc] Use ``OpenAIResponsesClient`` instead for a fully-featured client with all layers applied. 
""" + STORES_BY_DEFAULT: ClassVar[bool] = True # type: ignore[reportIncompatibleVariableOverride, misc] + FILE_SEARCH_MAX_RESULTS: int = 50 # region Inner Methods diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index bfa7525d7e..16365a7a5f 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -933,4 +933,54 @@ async def before_run(self, *, agent, session, context, state): assert options.get("instructions") == "Context-provided instructions" +# region STORES_BY_DEFAULT tests + + +async def test_stores_by_default_skips_inmemory_injection(client: SupportsChatGetResponse) -> None: + """Client with STORES_BY_DEFAULT=True should not auto-inject InMemoryHistoryProvider.""" + from agent_framework._sessions import InMemoryHistoryProvider + + # Simulate a client that stores by default + client.STORES_BY_DEFAULT = True # type: ignore[attr-defined] + + agent = Agent(client=client) + session = agent.create_session() + + await agent.run("Hello", session=session) + + # No InMemoryHistoryProvider should have been injected + assert not any(isinstance(p, InMemoryHistoryProvider) for p in agent.context_providers) + + +async def test_stores_by_default_false_injects_inmemory(client: SupportsChatGetResponse) -> None: + """Client with STORES_BY_DEFAULT=False (default) should auto-inject InMemoryHistoryProvider.""" + from agent_framework._sessions import InMemoryHistoryProvider + + agent = Agent(client=client) + session = agent.create_session() + + await agent.run("Hello", session=session) + + # InMemoryHistoryProvider should have been injected + assert any(isinstance(p, InMemoryHistoryProvider) for p in agent.context_providers) + + +async def test_stores_by_default_with_store_false_injects_inmemory(client: SupportsChatGetResponse) -> None: + """Client with STORES_BY_DEFAULT=True but store=False should still inject InMemoryHistoryProvider.""" + from 
agent_framework._sessions import InMemoryHistoryProvider + + client.STORES_BY_DEFAULT = True # type: ignore[attr-defined] + + agent = Agent(client=client) + session = agent.create_session() + + await agent.run("Hello", session=session, options={"store": False}) + + # User explicitly disabled server storage, so InMemoryHistoryProvider should be injected + assert any(isinstance(p, InMemoryHistoryProvider) for p in agent.context_providers) + + +# endregion + + # endregion diff --git a/python/packages/core/tests/core/test_settings.py b/python/packages/core/tests/core/test_settings.py index 12ff683924..8ab60ca043 100644 --- a/python/packages/core/tests/core/test_settings.py +++ b/python/packages/core/tests/core/test_settings.py @@ -28,6 +28,12 @@ class SecretSettings(TypedDict, total=False): username: str | None +class ExclusiveSettings(TypedDict, total=False): + source_a: str | None + source_b: str | None + other: str | None + + class TestLoadSettingsBasic: """Test basic load_settings functionality.""" @@ -236,3 +242,89 @@ def test_str_accepted_for_secretstring(self) -> None: assert isinstance(settings["api_key"], SecretString) assert settings["api_key"] == "plain-string" + + +class TestMutuallyExclusive: + """Test mutually exclusive field validation via tuple entries in required_fields.""" + + def test_exactly_one_set_passes(self) -> None: + settings = load_settings( + ExclusiveSettings, + env_prefix="TEST_", + required_fields=[("source_a", "source_b")], + source_a="value-a", + ) + + assert settings["source_a"] == "value-a" + assert settings["source_b"] is None + + def test_none_set_raises(self) -> None: + from agent_framework.exceptions import SettingNotFoundError + + with pytest.raises(SettingNotFoundError, match="none was set"): + load_settings( + ExclusiveSettings, + env_prefix="TEST_", + required_fields=[("source_a", "source_b")], + ) + + def test_both_set_raises(self) -> None: + from agent_framework.exceptions import SettingNotFoundError + + with 
pytest.raises(SettingNotFoundError, match="multiple were set"): + load_settings( + ExclusiveSettings, + env_prefix="TEST_", + required_fields=[("source_a", "source_b")], + source_a="a", + source_b="b", + ) + + def test_env_var_counts_as_set(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("TEST_SOURCE_B", "env-b") + + settings = load_settings( + ExclusiveSettings, + env_prefix="TEST_", + required_fields=[("source_a", "source_b")], + ) + + assert settings["source_b"] == "env-b" + + def test_env_var_and_override_both_set_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: + from agent_framework.exceptions import SettingNotFoundError + + monkeypatch.setenv("TEST_SOURCE_B", "env-b") + + with pytest.raises(SettingNotFoundError, match="multiple were set"): + load_settings( + ExclusiveSettings, + env_prefix="TEST_", + required_fields=[("source_a", "source_b")], + source_a="a", + ) + + def test_other_fields_unaffected(self) -> None: + settings = load_settings( + ExclusiveSettings, + env_prefix="TEST_", + required_fields=[("source_a", "source_b")], + source_a="a", + other="extra", + ) + + assert settings["source_a"] == "a" + assert settings["other"] == "extra" + + def test_mixed_required_and_exclusive(self) -> None: + settings = load_settings( + ExclusiveSettings, + env_prefix="TEST_", + required_fields=["other", ("source_a", "source_b")], + source_b="b", + other="required-val", + ) + + assert settings["other"] == "required-val" + assert settings["source_b"] == "b" + assert settings["source_a"] is None diff --git a/python/packages/core/tests/workflow/test_workflow_kwargs.py b/python/packages/core/tests/workflow/test_workflow_kwargs.py index 5a71afafe4..36aece0bb3 100644 --- a/python/packages/core/tests/workflow/test_workflow_kwargs.py +++ b/python/packages/core/tests/workflow/test_workflow_kwargs.py @@ -88,7 +88,7 @@ def run( messages: str | Message | Sequence[str | Message] | None = None, *, stream: bool = False, - thread: AgentThread | None = None, 
+ session: AgentSession | None = None, options: dict[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: diff --git a/python/packages/durabletask/agent_framework_durabletask/_shim.py b/python/packages/durabletask/agent_framework_durabletask/_shim.py index 10352d8bb7..1f6165133c 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_shim.py +++ b/python/packages/durabletask/agent_framework_durabletask/_shim.py @@ -136,7 +136,11 @@ def create_session(self, **kwargs: Any) -> DurableAgentSession: """Create a new agent session via the provider.""" return self._executor.get_new_session(self.name, **kwargs) - def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession: + def get_session(self, **kwargs: Any) -> AgentSession: + """Return a session for this agent via the provider. + + Durable agent sessions are not keyed by ``service_session_id``; this simply returns a new session. + """ return self._executor.get_new_session(self.name, **kwargs) def _normalize_messages(self, messages: str | Message | list[str] | list[Message] | None) -> str: diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py index 3bc07cd67c..ea5f8bd201 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py @@ -196,7 +196,7 @@ def __init__( agent: SupportsAgentRun, handoffs: Sequence[HandoffConfiguration], *, - agent_thread: AgentSession | None = None, + agent_session: AgentSession | None = None, is_start_agent: bool = False, termination_condition: TerminationCondition | None = None, autonomous_mode: bool = False, @@ -208,7 +208,7 @@ def __init__( Args: agent: The agent to execute handoffs: Sequence of handoff configurations defining target agents - agent_thread: Optional AgentSession that manages the
agent's execution context + agent_session: Optional AgentSession that manages the agent's execution context is_start_agent: Whether this agent is the starting agent in the handoff workflow. There can only be one starting agent in a handoff workflow. termination_condition: Optional callable that determines when to terminate the workflow @@ -222,7 +222,7 @@ def __init__( autonomous_mode_turn_limit: Maximum number of autonomous turns before requesting user input. """ cloned_agent = self._prepare_agent_with_handoffs(agent, handoffs) - super().__init__(cloned_agent, session=agent_thread) + super().__init__(cloned_agent, session=agent_session) self._handoff_targets = {handoff.target_id for handoff in handoffs} self._termination_condition = termination_condition diff --git a/python/packages/orchestrations/tests/test_group_chat.py b/python/packages/orchestrations/tests/test_group_chat.py index 59bae72314..7115e9c7d7 100644 --- a/python/packages/orchestrations/tests/test_group_chat.py +++ b/python/packages/orchestrations/tests/test_group_chat.py @@ -132,7 +132,7 @@ async def run( self, messages: str | Message | Sequence[str | Message] | None = None, *, - thread: AgentThread | None = None, + session: AgentSession | None = None, **kwargs: Any, ) -> AgentResponse: if self._call_count == 0: @@ -346,7 +346,7 @@ def __init__(self) -> None: super().__init__(name="", description="test") def run( - self, messages: Any = None, *, stream: bool = False, thread: Any = None, **kwargs: Any + self, messages: Any = None, *, stream: bool = False, session: Any = None, **kwargs: Any ) -> AgentResponse | AsyncIterable[AgentResponseUpdate]: if stream: diff --git a/python/samples/03-workflows/agents/azure_ai_agents_with_shared_thread.py b/python/samples/03-workflows/agents/azure_ai_agents_with_shared_thread.py index f33aa2ef10..988d3f539f 100644 --- a/python/samples/03-workflows/agents/azure_ai_agents_with_shared_thread.py +++ 
b/python/samples/03-workflows/agents/azure_ai_agents_with_shared_thread.py @@ -7,6 +7,7 @@ AgentExecutor, AgentExecutorRequest, AgentExecutorResponse, + InMemoryHistoryProvider, WorkflowBuilder, WorkflowContext, WorkflowRunState, From 0437ffdf86790381d8e33856181d6d7ebfc3e319 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Thu, 12 Feb 2026 19:41:37 +0100 Subject: [PATCH 22/28] Fix broken markdown links in azure_ai and redis READMEs --- python/packages/redis/README.md | 2 +- python/samples/02-agents/providers/azure_ai/README.md | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/packages/redis/README.md b/python/packages/redis/README.md index 3517f460de..43ab34c1ee 100644 --- a/python/packages/redis/README.md +++ b/python/packages/redis/README.md @@ -30,7 +30,7 @@ The `RedisChatMessageStore` provides persistent conversation storage using Redis #### Basic Usage Examples -See the complete [Redis history provider examples](../../samples/02-agents/conversations/redis_history_provider.py) including: +See the complete [Redis history provider examples](../../samples/02-agents/conversations/redis_chat_message_store_thread.py) including: - User session management - Conversation persistence across restarts - Session serialization and deserialization diff --git a/python/samples/02-agents/providers/azure_ai/README.md b/python/samples/02-agents/providers/azure_ai/README.md index 075219462c..a047ccb9b0 100644 --- a/python/samples/02-agents/providers/azure_ai/README.md +++ b/python/samples/02-agents/providers/azure_ai/README.md @@ -28,8 +28,8 @@ This folder contains examples demonstrating different ways to create and use age | [`azure_ai_with_local_mcp.py`](azure_ai_with_local_mcp.py) | Shows how to integrate local Model Context Protocol (MCP) tools with Azure AI agents. 
| | [`azure_ai_with_response_format.py`](azure_ai_with_response_format.py) | Shows how to use structured outputs (response format) with Azure AI agents using Pydantic models to enforce specific response schemas. | | [`azure_ai_with_runtime_json_schema.py`](azure_ai_with_runtime_json_schema.py) | Shows how to use structured outputs (response format) with Azure AI agents using a JSON schema to enforce specific response schemas. | -| [`azure_ai_with_search_context_agentic.py`](../../sessions/azure_ai_search/azure_ai_with_search_context_agentic.py) | Shows how to use AzureAISearchContextProvider with agentic mode. Uses Knowledge Bases for multi-hop reasoning across documents with query planning. Recommended for most scenarios - slightly slower with more token consumption for query planning, but more accurate results. | -| [`azure_ai_with_search_context_semantic.py`](../../sessions/azure_ai_search/azure_ai_with_search_context_semantic.py) | Shows how to use AzureAISearchContextProvider with semantic mode. Fast hybrid search with vector + keyword search and semantic ranking for RAG. Best for simple queries where speed is critical. | +| [`azure_ai_with_search_context_agentic.py`](../../context_providers/azure_ai_search/azure_ai_with_search_context_agentic.py) | Shows how to use AzureAISearchContextProvider with agentic mode. Uses Knowledge Bases for multi-hop reasoning across documents with query planning. Recommended for most scenarios - slightly slower with more token consumption for query planning, but more accurate results. | +| [`azure_ai_with_search_context_semantic.py`](../../context_providers/azure_ai_search/azure_ai_with_search_context_semantic.py) | Shows how to use AzureAISearchContextProvider with semantic mode. Fast hybrid search with vector + keyword search and semantic ranking for RAG. Best for simple queries where speed is critical. 
| | [`azure_ai_with_sharepoint.py`](azure_ai_with_sharepoint.py) | Shows how to use SharePoint grounding with Azure AI agents to search through SharePoint content and answer user questions with proper citations. Requires a SharePoint connection configured in your Azure AI project. | | [`azure_ai_with_thread.py`](azure_ai_with_thread.py) | Demonstrates session management with Azure AI agents, including automatic session creation for stateless conversations and explicit session management for maintaining conversation context across multiple interactions. | | [`azure_ai_with_image_generation.py`](azure_ai_with_image_generation.py) | Shows how to use `AzureAIClient.get_image_generation_tool()` with Azure AI agents to generate images based on text prompts. | From fcd2cb0697935d16d6a7b26a8ee4db96f1805e70 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Thu, 12 Feb 2026 19:52:53 +0100 Subject: [PATCH 23/28] Fix getting-started samples to use session API instead of removed thread/ContextProvider API --- .../samples/01-get-started/03_multi_turn.py | 8 ++-- python/samples/01-get-started/04_memory.py | 44 +++++++++++-------- 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/python/samples/01-get-started/03_multi_turn.py b/python/samples/01-get-started/03_multi_turn.py index f69930619e..4f7d7dacbe 100644 --- a/python/samples/01-get-started/03_multi_turn.py +++ b/python/samples/01-get-started/03_multi_turn.py @@ -34,15 +34,15 @@ async def main() -> None: # # - # Create a thread to maintain conversation history - thread = agent.get_new_thread() + # Create a session to maintain conversation history + session = agent.create_session() # First turn - result = await agent.run("My name is Alice and I love hiking.", thread=thread) + result = await agent.run("My name is Alice and I love hiking.", session=session) print(f"Agent: {result}\n") # Second turn — the agent should remember the user's name and hobby - result = await agent.run("What do you remember about me?", 
thread=thread) + result = await agent.run("What do you remember about me?", session=session) print(f"Agent: {result}") # diff --git a/python/samples/01-get-started/04_memory.py b/python/samples/01-get-started/04_memory.py index 08320a6e43..d35ef21581 100644 --- a/python/samples/01-get-started/04_memory.py +++ b/python/samples/01-get-started/04_memory.py @@ -2,10 +2,9 @@ import asyncio import os -from collections.abc import MutableSequence from typing import Any -from agent_framework import Context, ContextProvider, Message +from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext from agent_framework.azure import AzureOpenAIResponsesClient from azure.identity import AzureCliCredential @@ -23,28 +22,37 @@ # -class UserNameProvider(ContextProvider): +class UserNameProvider(BaseContextProvider): """A simple context provider that remembers the user's name.""" def __init__(self) -> None: + super().__init__(source_id="user-name-provider") self.user_name: str | None = None - async def invoking(self, messages: Message | MutableSequence[Message], **kwargs: Any) -> Context: + async def before_run( + self, + *, + agent: Any, + session: AgentSession, + context: SessionContext, + state: dict[str, Any], + ) -> None: """Called before each agent invocation — add extra instructions.""" if self.user_name: - return Context(instructions=f"The user's name is {self.user_name}. Always address them by name.") - return Context(instructions="You don't know the user's name yet. Ask for it politely.") + context.instructions.append(f"The user's name is {self.user_name}. Always address them by name.") + else: + context.instructions.append("You don't know the user's name yet. 
Ask for it politely.") - async def invoked( + async def after_run( self, - request_messages: Message | list[Message] | None = None, - response_messages: "Message | list[Message] | None" = None, - invoke_exception: Exception | None = None, - **kwargs: Any, + *, + agent: Any, + session: AgentSession, + context: SessionContext, + state: dict[str, Any], ) -> None: """Called after each agent invocation — extract information.""" - msgs = [request_messages] if isinstance(request_messages, Message) else list(request_messages or []) - for msg in msgs: + for msg in context.input_messages: text = msg.text if hasattr(msg, "text") else "" if isinstance(text, str) and "my name is" in text.lower(): # Simple extraction — production code should use structured extraction @@ -66,22 +74,22 @@ async def main() -> None: agent = client.as_agent( name="MemoryAgent", instructions="You are a friendly assistant.", - context_provider=memory, + context_providers=[memory], ) # - thread = agent.get_new_thread() + session = agent.create_session() # The provider doesn't know the user yet — it will ask for a name - result = await agent.run("Hello! What's the square root of 9?", thread=thread) + result = await agent.run("Hello! 
What's the square root of 9?", session=session) print(f"Agent: {result}\n") # Now provide the name — the provider extracts and stores it - result = await agent.run("My name is Alice", thread=thread) + result = await agent.run("My name is Alice", session=session) print(f"Agent: {result}\n") # Subsequent calls are personalized - result = await agent.run("What is 2 + 2?", thread=thread) + result = await agent.run("What is 2 + 2?", session=session) print(f"Agent: {result}\n") print(f"[Memory] Stored user name: {memory.user_name}") From 9b4df38e5e1b47643699166e5c3e60d2993d4d67 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Thu, 12 Feb 2026 20:30:49 +0100 Subject: [PATCH 24/28] updates to workflow as agent --- .../packages/core/agent_framework/_agents.py | 47 +++-- python/packages/core/agent_framework/_mcp.py | 75 +++++--- .../core/agent_framework/_settings.py | 4 +- .../core/agent_framework/_workflows/_agent.py | 104 ++++++++--- .../_workflows/_agent_executor.py | 20 +-- .../_workflows/_conversation_state.py | 75 -------- .../azure/_responses_client.py | 15 +- python/packages/core/tests/core/test_mcp.py | 6 +- python/uv.lock | 162 ++++++++++-------- 9 files changed, 250 insertions(+), 258 deletions(-) delete mode 100644 python/packages/core/agent_framework/_workflows/_conversation_state.py diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 56a616471c..b6dbca8099 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -408,6 +408,27 @@ def get_session(self, *, service_session_id: str, session_id: str | None = None, """ return AgentSession(session_id=session_id, service_session_id=service_session_id) + async def _run_after_providers( + self, + *, + session: AgentSession | None, + context: SessionContext, + ) -> None: + """Run after_run on all context providers in reverse order. + + Keyword Args: + session: The conversation session. 
+ context: The invocation context with response populated. + """ + state = session.state if session else {} + for provider in reversed(self.context_providers): + await provider.after_run( + agent=self, # type: ignore[arg-type] + session=session, # type: ignore[arg-type] + context=context, + state=state, + ) + def as_tool( self, *, @@ -973,10 +994,7 @@ async def _prepare_run_context( and not session.service_session_id and not opts.get("conversation_id") and not opts.get("store") - and not ( - getattr(self.client, "STORES_BY_DEFAULT", False) - and opts.get("store") is not False - ) + and not (getattr(self.client, "STORES_BY_DEFAULT", False) and opts.get("store") is not False) ): self.context_providers.append(InMemoryHistoryProvider("memory")) @@ -1082,27 +1100,6 @@ async def _finalize_response( # Run after_run providers (reverse order) await self._run_after_providers(session=session, context=session_context) - async def _run_after_providers( - self, - *, - session: AgentSession | None, - context: SessionContext, - ) -> None: - """Run after_run on all context providers in reverse order. - - Keyword Args: - session: The conversation session. - context: The invocation context with response populated. 
- """ - state = session.state if session else {} - for provider in reversed(self.context_providers): - await provider.after_run( - agent=self, # type: ignore[arg-type] - session=session, # type: ignore[arg-type] - context=context, - state=state, - ) - async def _prepare_session_and_messages( self, *, diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index 64ff60fa7f..61c8620bf7 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -102,21 +102,31 @@ def _parse_prompt_result_from_mcp( if isinstance(content, types.TextContent): parts.append(content.text) elif isinstance(content, (types.ImageContent, types.AudioContent)): - parts.append(json.dumps({ - "type": "image" if isinstance(content, types.ImageContent) else "audio", - "data": content.data, - "mimeType": content.mimeType, - }, default=str)) + parts.append( + json.dumps( + { + "type": "image" if isinstance(content, types.ImageContent) else "audio", + "data": content.data, + "mimeType": content.mimeType, + }, + default=str, + ) + ) elif isinstance(content, types.EmbeddedResource): match content.resource: case types.TextResourceContents(): parts.append(content.resource.text) case types.BlobResourceContents(): - parts.append(json.dumps({ - "type": "blob", - "data": content.resource.blob, - "mimeType": content.resource.mimeType, - }, default=str)) + parts.append( + json.dumps( + { + "type": "blob", + "data": content.resource.blob, + "mimeType": content.resource.mimeType, + }, + default=str, + ) + ) else: parts.append(str(content)) if not parts: @@ -159,27 +169,42 @@ def _parse_tool_result_from_mcp( case types.TextContent(): parts.append(item.text) case types.ImageContent() | types.AudioContent(): - parts.append(json.dumps({ - "type": "image" if isinstance(item, types.ImageContent) else "audio", - "data": item.data, - "mimeType": item.mimeType, - }, default=str)) + parts.append( + json.dumps( + { + "type": 
"image" if isinstance(item, types.ImageContent) else "audio", + "data": item.data, + "mimeType": item.mimeType, + }, + default=str, + ) + ) case types.ResourceLink(): - parts.append(json.dumps({ - "type": "resource_link", - "uri": str(item.uri), - "mimeType": item.mimeType, - }, default=str)) + parts.append( + json.dumps( + { + "type": "resource_link", + "uri": str(item.uri), + "mimeType": item.mimeType, + }, + default=str, + ) + ) case types.EmbeddedResource(): match item.resource: case types.TextResourceContents(): parts.append(item.resource.text) case types.BlobResourceContents(): - parts.append(json.dumps({ - "type": "blob", - "data": item.resource.blob, - "mimeType": item.resource.mimeType, - }, default=str)) + parts.append( + json.dumps( + { + "type": "blob", + "data": item.resource.blob, + "mimeType": item.resource.mimeType, + }, + default=str, + ) + ) case _: parts.append(str(item)) if not parts: diff --git a/python/packages/core/agent_framework/_settings.py b/python/packages/core/agent_framework/_settings.py index 57919e7124..30fd8c8508 100644 --- a/python/packages/core/agent_framework/_settings.py +++ b/python/packages/core/agent_framework/_settings.py @@ -269,9 +269,7 @@ def load_settings( set_fields = [f for f in entry if result.get(f) is not None] if len(set_fields) == 0: names = ", ".join(f"'{f}'" for f in entry) - raise SettingNotFoundError( - f"Exactly one of {names} must be provided, but none was set." 
- ) + raise SettingNotFoundError(f"Exactly one of {names} must be provided, but none was set.") if len(set_fields) > 1: all_names = ", ".join(f"'{f}'" for f in entry) set_names = ", ".join(f"'{f}'" for f in set_fields) diff --git a/python/packages/core/agent_framework/_workflows/_agent.py b/python/packages/core/agent_framework/_workflows/_agent.py index dfa078a9fa..ef2b127d45 100644 --- a/python/packages/core/agent_framework/_workflows/_agent.py +++ b/python/packages/core/agent_framework/_workflows/_agent.py @@ -6,23 +6,22 @@ import logging import sys import uuid -from collections.abc import AsyncIterable, Awaitable +from collections.abc import AsyncIterable, Awaitable, Sequence from dataclasses import dataclass from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload -from agent_framework import ( +from .._agents import BaseAgent +from .._sessions import AgentSession, BaseContextProvider, BaseHistoryProvider, SessionContext +from .._types import ( AgentResponse, AgentResponseUpdate, - AgentSession, - BaseAgent, Content, Message, ResponseStream, UsageDetails, + add_usage_details, ) - -from .._types import add_usage_details from ..exceptions import AgentExecutionException from ._checkpoint import CheckpointStorage from ._events import ( @@ -80,6 +79,7 @@ def __init__( id: str | None = None, name: str | None = None, description: str | None = None, + context_providers: Sequence[BaseContextProvider] | None = None, **kwargs: Any, ) -> None: """Initialize the WorkflowAgent. @@ -91,6 +91,7 @@ def __init__( id: Unique identifier for the agent. If None, will be generated. name: Optional name for the agent. description: Optional description of the agent. + context_providers: Optional sequence of context providers for the agent. **kwargs: Additional keyword arguments passed to BaseAgent. 
Note: @@ -111,7 +112,7 @@ def __init__( if not any(is_type_compatible(list[Message], input_type) for input_type in start_executor.input_types): raise ValueError("Workflow's start executor cannot handle list[Message]") - super().__init__(id=id, name=name, description=description, **kwargs) + super().__init__(id=id, name=name, description=description, context_providers=context_providers, **kwargs) self._workflow: Workflow = workflow self._pending_requests: dict[str, WorkflowEvent[Any]] = {} @@ -128,7 +129,7 @@ def pending_requests(self) -> dict[str, WorkflowEvent[Any]]: @overload def run( self, - messages: str | Message | list[str] | list[Message] | None = None, + messages: str | Message | Sequence[str | Message] | None = None, *, stream: Literal[True], session: AgentSession | None = None, @@ -140,7 +141,7 @@ def run( @overload async def run( self, - messages: str | Message | list[str] | list[Message] | None = None, + messages: str | Message | Sequence[str | Message] | None = None, *, stream: Literal[False] = ..., session: AgentSession | None = None, @@ -151,7 +152,7 @@ async def run( def run( self, - messages: str | Message | list[str] | list[Message] | None = None, + messages: str | Message | Sequence[str | Message] | None = None, *, stream: bool = False, session: AgentSession | None = None, @@ -185,23 +186,19 @@ def run( or AgentResponseUpdate objects. Request info events (type='request_info') will be converted to function call and approval request contents. 
""" - input_messages = normalize_messages_input(messages) + if messages is None: + messages = [] response_id = str(uuid.uuid4()) if stream: return ResponseStream( - self._run_stream_impl( - input_messages, response_id, session, checkpoint_id, checkpoint_storage, **kwargs - ), + self._run_stream_impl(messages, response_id, session, checkpoint_id, checkpoint_storage, **kwargs), finalizer=AgentResponse.from_updates, ) - input_messages = normalize_messages_input(messages) - response_id = str(uuid.uuid4()) - - return self._run_impl(input_messages, response_id, session, checkpoint_id, checkpoint_storage, **kwargs) + return self._run_impl(messages, response_id, session, checkpoint_id, checkpoint_storage, **kwargs) async def _run_impl( self, - input_messages: list[Message], + messages: str | Message | Sequence[str | Message], response_id: str, session: AgentSession | None, checkpoint_id: str | None = None, @@ -211,7 +208,7 @@ async def _run_impl( """Internal implementation of non-streaming execution. Args: - input_messages: Normalized input messages to process. + messages: Normalized input messages to process. response_id: The unique response ID for this workflow execution. session: The agent session for conversation context. checkpoint_id: ID of checkpoint to restore from. @@ -222,16 +219,42 @@ async def _run_impl( Returns: An AgentResponse representing the workflow execution results. 
""" + input_messages = normalize_messages_input(messages) + + # run the context providers with the session + session_context = SessionContext( + session_id=session.session_id if session else None, + service_session_id=session.service_session_id if session else None, + input_messages=input_messages or [], + options={}, + ) + state = session.state if session else {} + for provider in self.context_providers: + if isinstance(provider, BaseHistoryProvider) and not provider.load_messages: + continue + await provider.before_run( + agent=self, # type: ignore[arg-type] + session=session, # type: ignore[arg-type] + context=session_context, + state=state, + ) + # combine the messages + session_messages: list[Message] = session_context.get_messages(include_input=True) + output_events: list[WorkflowEvent[Any]] = [] - async for event in self._run_core(input_messages, checkpoint_id, checkpoint_storage, streaming=False, **kwargs): + async for event in self._run_core( + session_messages, checkpoint_id, checkpoint_storage, streaming=False, **kwargs + ): if event.type == "output" or event.type == "request_info": output_events.append(event) - return self._convert_workflow_events_to_agent_response(response_id, output_events) + result = self._convert_workflow_events_to_agent_response(response_id, output_events) + await self._run_after_providers(session=session, context=session_context) + return result async def _run_stream_impl( self, - input_messages: list[Message], + messages: str | Message | Sequence[str | Message], response_id: str, session: AgentSession | None, checkpoint_id: str | None = None, @@ -241,7 +264,7 @@ async def _run_stream_impl( """Internal implementation of streaming execution. Args: - input_messages: Normalized input messages to process. + messages: Input messages to process. response_id: The unique response ID for this workflow execution. session: The agent session for conversation context. checkpoint_id: ID of checkpoint to restore from. 
@@ -252,14 +275,39 @@ async def _run_stream_impl( Yields: AgentResponseUpdate objects representing the workflow execution progress. """ - async for event in self._run_core(input_messages, checkpoint_id, checkpoint_storage, streaming=True, **kwargs): + input_messages = normalize_messages_input(messages) + + # run the context providers with the session + session_context = SessionContext( + session_id=session.session_id if session else None, + service_session_id=session.service_session_id if session else None, + input_messages=input_messages or [], + options={}, + ) + state = session.state if session else {} + for provider in self.context_providers: + if isinstance(provider, BaseHistoryProvider) and not provider.load_messages: + continue + await provider.before_run( + agent=self, # type: ignore[arg-type] + session=session, # type: ignore[arg-type] + context=session_context, + state=state, + ) + # combine the messages + + session_messages: list[Message] = session_context.get_messages(include_input=True) + async for event in self._run_core( + session_messages, checkpoint_id, checkpoint_storage, streaming=True, **kwargs + ): updates = self._convert_workflow_event_to_agent_response_updates(response_id, event) for update in updates: yield update + await self._run_after_providers(session=session, context=session_context) async def _run_core( self, - input_messages: list[Message], + input_messages: Sequence[Message], checkpoint_id: str | None, checkpoint_storage: CheckpointStorage | None, streaming: bool, @@ -327,7 +375,7 @@ async def _run_core( # endregion Run Methods - def _process_pending_requests(self, input_messages: list[Message]) -> dict[str, Any]: + def _process_pending_requests(self, input_messages: Sequence[Message]) -> dict[str, Any]: """Process pending requests by extracting function responses and updating state. 
Args: @@ -584,7 +632,7 @@ def _convert_workflow_event_to_agent_response_updates( # Ignore workflow-internal events return [] - def _extract_function_responses(self, input_messages: list[Message]) -> dict[str, Any]: + def _extract_function_responses(self, input_messages: Sequence[Message]) -> dict[str, Any]: """Extract function responses from input messages.""" function_responses: dict[str, Any] = {} for message in input_messages: diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 74dbca3540..3f0bd63086 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -14,9 +14,7 @@ from .._sessions import AgentSession from .._types import AgentResponse, AgentResponseUpdate, Message from ._agent_utils import resolve_agent_id -from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value from ._const import WORKFLOW_RUN_KWARGS_KEY -from ._conversation_state import encode_chat_messages from ._executor import Executor, handler from ._message_utils import normalize_messages_input from ._request_info_mixin import response_handler @@ -231,11 +229,11 @@ async def on_checkpoint_save(self) -> dict[str, Any]: serialized_session = self._session.to_dict() return { - "cache": encode_chat_messages(self._cache), - "full_conversation": encode_chat_messages(self._full_conversation), + "cache": self._cache, + "full_conversation": self._full_conversation, "agent_session": serialized_session, - "pending_agent_requests": encode_checkpoint_value(self._pending_agent_requests), - "pending_responses_to_agent": encode_checkpoint_value(self._pending_responses_to_agent), + "pending_agent_requests": self._pending_agent_requests, + "pending_responses_to_agent": self._pending_responses_to_agent, } @override @@ -245,12 +243,10 @@ async def on_checkpoint_restore(self, state: dict[str, 
Any]) -> None: Args: state: Checkpoint data dict """ - from ._conversation_state import decode_chat_messages - cache_payload = state.get("cache") if cache_payload: try: - self._cache = decode_chat_messages(cache_payload) + self._cache = cache_payload except Exception as exc: logger.warning("Failed to restore cache: %s", exc) self._cache = [] @@ -260,7 +256,7 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: full_conversation_payload = state.get("full_conversation") if full_conversation_payload: try: - self._full_conversation = decode_chat_messages(full_conversation_payload) + self._full_conversation = full_conversation_payload except Exception as exc: logger.warning("Failed to restore full conversation: %s", exc) self._full_conversation = [] @@ -279,11 +275,11 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: pending_requests_payload = state.get("pending_agent_requests") if pending_requests_payload: - self._pending_agent_requests = decode_checkpoint_value(pending_requests_payload) + self._pending_agent_requests = pending_requests_payload pending_responses_payload = state.get("pending_responses_to_agent") if pending_responses_payload: - self._pending_responses_to_agent = decode_checkpoint_value(pending_responses_payload) + self._pending_responses_to_agent = pending_responses_payload def reset(self) -> None: """Reset the internal cache of the executor.""" diff --git a/python/packages/core/agent_framework/_workflows/_conversation_state.py b/python/packages/core/agent_framework/_workflows/_conversation_state.py deleted file mode 100644 index 95945998df..0000000000 --- a/python/packages/core/agent_framework/_workflows/_conversation_state.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. 
- -from collections.abc import Iterable -from typing import Any, cast - -from agent_framework import Message - -from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value - -"""Utilities for serializing and deserializing chat conversations for persistence. - -These helpers convert rich `Message` instances to checkpoint-friendly payloads -using the same encoding primitives as the workflow runner. This preserves -`additional_properties` and other metadata without relying on unsafe mechanisms -such as pickling. -""" - - -def encode_chat_messages(messages: Iterable[Message]) -> list[dict[str, Any]]: - """Serialize chat messages into checkpoint-safe payloads.""" - encoded: list[dict[str, Any]] = [] - for message in messages: - encoded.append({ - "role": encode_checkpoint_value(message.role), - "contents": [encode_checkpoint_value(content) for content in message.contents], - "author_name": message.author_name, - "message_id": message.message_id, - "additional_properties": { - key: encode_checkpoint_value(value) for key, value in message.additional_properties.items() - }, - }) - return encoded - - -def decode_chat_messages(payload: Iterable[dict[str, Any]]) -> list[Message]: - """Restore chat messages from checkpoint-safe payloads.""" - restored: list[Message] = [] - for item in payload: - if not isinstance(item, dict): - continue - - role_value = decode_checkpoint_value(item.get("role")) - if isinstance(role_value, str): - role = role_value - elif isinstance(role_value, dict) and "value" in role_value: - # Handle legacy serialization format - role = role_value["value"] - else: - role = "assistant" - - contents_field = item.get("contents", []) - contents: list[Any] = [] - if isinstance(contents_field, list): - contents_iter: list[Any] = contents_field # type: ignore[assignment] - for entry in contents_iter: - decoded_entry: Any = decode_checkpoint_value(entry) - contents.append(decoded_entry) - - additional_field = item.get("additional_properties", 
{}) - additional: dict[str, Any] = {} - if isinstance(additional_field, dict): - additional_dict = cast(dict[str, Any], additional_field) - for key, value in additional_dict.items(): - additional[key] = decode_checkpoint_value(value) - - restored.append( - Message( # type: ignore[call-overload] - role=role, - contents=contents, - author_name=item.get("author_name"), - message_id=item.get("message_id"), - additional_properties=additional, - ) - ) - return restored diff --git a/python/packages/core/agent_framework/azure/_responses_client.py b/python/packages/core/agent_framework/azure/_responses_client.py index 65335482fe..0a6c0cd8c8 100644 --- a/python/packages/core/agent_framework/azure/_responses_client.py +++ b/python/packages/core/agent_framework/azure/_responses_client.py @@ -83,8 +83,7 @@ def __init__( env_file_encoding: str | None = None, instruction_role: str | None = None, middleware: Sequence[MiddlewareTypes] | None = None, - function_invocation_configuration: FunctionInvocationConfiguration - | None = None, + function_invocation_configuration: FunctionInvocationConfiguration | None = None, **kwargs: Any, ) -> None: """Initialize an Azure OpenAI Responses client. 
@@ -190,9 +189,7 @@ class MyOptions(AzureOpenAIResponsesOptions, total=False): deployment_name = str(model_id) # Project client path: create OpenAI client from an Azure AI Foundry project - if async_client is None and ( - project_client is not None or project_endpoint is not None - ): + if async_client is None and (project_client is not None or project_endpoint is not None): async_client = self._create_client_from_project( project_client=project_client, project_endpoint=project_endpoint, @@ -221,9 +218,7 @@ class MyOptions(AzureOpenAIResponsesOptions, total=False): and (hostname := urlparse(str(azure_openai_settings["endpoint"])).hostname) and hostname.endswith(".openai.azure.com") ): - azure_openai_settings["base_url"] = urljoin( - str(azure_openai_settings["endpoint"]), "/openai/v1/" - ) + azure_openai_settings["base_url"] = urljoin(str(azure_openai_settings["endpoint"]), "/openai/v1/") if not azure_openai_settings["responses_deployment_name"]: raise ServiceInitializationError( @@ -236,9 +231,7 @@ class MyOptions(AzureOpenAIResponsesOptions, total=False): endpoint=azure_openai_settings["endpoint"], base_url=azure_openai_settings["base_url"], api_version=azure_openai_settings["api_version"], # type: ignore - api_key=azure_openai_settings["api_key"].get_secret_value() - if azure_openai_settings["api_key"] - else None, + api_key=azure_openai_settings["api_key"].get_secret_value() if azure_openai_settings["api_key"] else None, ad_token=ad_token, ad_token_provider=ad_token_provider, token_endpoint=azure_openai_settings["token_endpoint"], diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 21ff396a52..06376137d6 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -25,8 +25,8 @@ _get_input_model_from_mcp_tool, _normalize_mcp_name, _parse_content_from_mcp, - _parse_tool_result_from_mcp, _parse_message_from_mcp, + _parse_tool_result_from_mcp, 
_prepare_content_for_mcp, _prepare_message_for_mcp, logger, @@ -97,9 +97,7 @@ def test_parse_tool_result_from_mcp(): def test_parse_tool_result_from_mcp_single_text(): """Test conversion from MCP tool result with a single text item.""" - mcp_result = types.CallToolResult( - content=[types.TextContent(type="text", text="Simple result")] - ) + mcp_result = types.CallToolResult(content=[types.TextContent(type="text", text="Simple result")]) result = _parse_tool_result_from_mcp(mcp_result) # Single text item returns just the text diff --git a/python/uv.lock b/python/uv.lock index bab915d16b..a3e622c754 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -1853,7 +1853,7 @@ wheels = [ [[package]] name = "fastapi" -version = "0.128.8" +version = "0.129.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "annotated-doc", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -1862,9 +1862,9 @@ dependencies = [ { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "typing-inspection", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/01/72/0df5c58c954742f31a7054e2dd1143bae0b408b7f36b59b85f928f9b456c/fastapi-0.128.8.tar.gz", hash = "sha256:3171f9f328c4a218f0a8d2ba8310ac3a55d1ee12c28c949650288aee25966007", size = 375523, upload-time = "2026-02-11T15:19:36.69Z" } +sdist = { url = "https://files.pythonhosted.org/packages/48/47/75f6bea02e797abff1bca968d5997793898032d9923c1935ae2efdece642/fastapi-0.129.0.tar.gz", hash = "sha256:61315cebd2e65df5f97ec298c888f9de30430dd0612d59d6480beafbc10655af", size = 375450, upload-time = "2026-02-12T13:54:52.541Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/9f/37/37b07e276f8923c69a5df266bfcb5bac4ba8b55dfe4a126720f8c48681d1/fastapi-0.128.8-py3-none-any.whl", hash = 
"sha256:5618f492d0fe973a778f8fec97723f598aa9deee495040a8d51aaf3cf123ecf1", size = 103630, upload-time = "2026-02-11T15:19:35.209Z" }, + { url = "https://files.pythonhosted.org/packages/9e/dd/d0ee25348ac58245ee9f90b6f3cbb666bf01f69be7e0911f9851bddbda16/fastapi-0.129.0-py3-none-any.whl", hash = "sha256:b4946880e48f462692b31c083be0432275cbfb6e2274566b1be91479cc1a84ec", size = 102950, upload-time = "2026-02-12T13:54:54.528Z" }, ] [[package]] @@ -1947,11 +1947,11 @@ wheels = [ [[package]] name = "filelock" -version = "3.20.3" +version = "3.21.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/1d/65/ce7f1b70157833bf3cb851b556a37d4547ceafc158aa9b34b36782f23696/filelock-3.20.3.tar.gz", hash = "sha256:18c57ee915c7ec61cff0ecf7f0f869936c7c30191bb0cf406f1341778d0834e1", size = 19485, upload-time = "2026-01-09T17:55:05.421Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c0/6b/cc63cdbff46eba1ce2fbd058e9699f99c43f7e604da15413ca0331040bff/filelock-3.21.0.tar.gz", hash = "sha256:48c739c73c6fcacd381ed532226991150947c4a76dcd674f84d6807fd55dbaf2", size = 31341, upload-time = "2026-02-12T15:40:48.544Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b5/36/7fb70f04bf00bc646cd5bb45aa9eddb15e19437a28b8fb2b4a5249fac770/filelock-3.20.3-py3-none-any.whl", hash = "sha256:4b0dda527ee31078689fc205ec4f1c1bf7d56cf88b6dc9426c4f230e46c2dce1", size = 16701, upload-time = "2026-01-09T17:55:04.334Z" }, + { url = "https://files.pythonhosted.org/packages/da/ab/05190b5a64101fcb743bc63a034c0fac86a515c27c303c69221093565f28/filelock-3.21.0-py3-none-any.whl", hash = "sha256:0f90eee4c62101243df3007db3cf8fc3ebf1bb13541d3e72c687d6e0f3f7d531", size = 21381, upload-time = "2026-02-12T15:40:46.964Z" }, ] [[package]] @@ -2980,75 +2980,87 @@ wheels = [ [[package]] name = "librt" -version = "0.7.8" -source = { registry = "https://pypi.org/simple" } -sdist = { url = 
"https://files.pythonhosted.org/packages/e7/24/5f3646ff414285e0f7708fa4e946b9bf538345a41d1c375c439467721a5e/librt-0.7.8.tar.gz", hash = "sha256:1a4ede613941d9c3470b0368be851df6bb78ab218635512d0370b27a277a0862", size = 148323, upload-time = "2026-01-14T12:56:16.876Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/44/13/57b06758a13550c5f09563893b004f98e9537ee6ec67b7df85c3571c8832/librt-0.7.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b45306a1fc5f53c9330fbee134d8b3227fe5da2ab09813b892790400aa49352d", size = 56521, upload-time = "2026-01-14T12:54:40.066Z" }, - { url = "https://files.pythonhosted.org/packages/c2/24/bbea34d1452a10612fb45ac8356f95351ba40c2517e429602160a49d1fd0/librt-0.7.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:864c4b7083eeee250ed55135d2127b260d7eb4b5e953a9e5df09c852e327961b", size = 58456, upload-time = "2026-01-14T12:54:41.471Z" }, - { url = "https://files.pythonhosted.org/packages/04/72/a168808f92253ec3a810beb1eceebc465701197dbc7e865a1c9ceb3c22c7/librt-0.7.8-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:6938cc2de153bc927ed8d71c7d2f2ae01b4e96359126c602721340eb7ce1a92d", size = 164392, upload-time = "2026-01-14T12:54:42.843Z" }, - { url = "https://files.pythonhosted.org/packages/14/5c/4c0d406f1b02735c2e7af8ff1ff03a6577b1369b91aa934a9fa2cc42c7ce/librt-0.7.8-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:66daa6ac5de4288a5bbfbe55b4caa7bf0cd26b3269c7a476ffe8ce45f837f87d", size = 172959, upload-time = "2026-01-14T12:54:44.602Z" }, - { url = "https://files.pythonhosted.org/packages/82/5f/3e85351c523f73ad8d938989e9a58c7f59fb9c17f761b9981b43f0025ce7/librt-0.7.8-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4864045f49dc9c974dadb942ac56a74cd0479a2aafa51ce272c490a82322ea3c", size = 186717, upload-time = "2026-01-14T12:54:45.986Z" }, - { url = 
"https://files.pythonhosted.org/packages/08/f8/18bfe092e402d00fe00d33aa1e01dda1bd583ca100b393b4373847eade6d/librt-0.7.8-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a36515b1328dc5b3ffce79fe204985ca8572525452eacabee2166f44bb387b2c", size = 184585, upload-time = "2026-01-14T12:54:47.139Z" }, - { url = "https://files.pythonhosted.org/packages/4e/fc/f43972ff56fd790a9fa55028a52ccea1875100edbb856b705bd393b601e3/librt-0.7.8-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:b7e7f140c5169798f90b80d6e607ed2ba5059784968a004107c88ad61fb3641d", size = 180497, upload-time = "2026-01-14T12:54:48.946Z" }, - { url = "https://files.pythonhosted.org/packages/e1/3a/25e36030315a410d3ad0b7d0f19f5f188e88d1613d7d3fd8150523ea1093/librt-0.7.8-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ff71447cb778a4f772ddc4ce360e6ba9c95527ed84a52096bd1bbf9fee2ec7c0", size = 200052, upload-time = "2026-01-14T12:54:50.382Z" }, - { url = "https://files.pythonhosted.org/packages/fc/b8/f3a5a1931ae2a6ad92bf6893b9ef44325b88641d58723529e2c2935e8abe/librt-0.7.8-cp310-cp310-win32.whl", hash = "sha256:047164e5f68b7a8ebdf9fae91a3c2161d3192418aadd61ddd3a86a56cbe3dc85", size = 43477, upload-time = "2026-01-14T12:54:51.815Z" }, - { url = "https://files.pythonhosted.org/packages/fe/91/c4202779366bc19f871b4ad25db10fcfa1e313c7893feb942f32668e8597/librt-0.7.8-cp310-cp310-win_amd64.whl", hash = "sha256:d6f254d096d84156a46a84861183c183d30734e52383602443292644d895047c", size = 49806, upload-time = "2026-01-14T12:54:53.149Z" }, - { url = "https://files.pythonhosted.org/packages/1b/a3/87ea9c1049f2c781177496ebee29430e4631f439b8553a4969c88747d5d8/librt-0.7.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ff3e9c11aa260c31493d4b3197d1e28dd07768594a4f92bec4506849d736248f", size = 56507, upload-time = "2026-01-14T12:54:54.156Z" }, - { url = "https://files.pythonhosted.org/packages/5e/4a/23bcef149f37f771ad30203d561fcfd45b02bc54947b91f7a9ac34815747/librt-0.7.8-cp311-cp311-macosx_11_0_arm64.whl", hash = 
"sha256:ddb52499d0b3ed4aa88746aaf6f36a08314677d5c346234c3987ddc506404eac", size = 58455, upload-time = "2026-01-14T12:54:55.978Z" }, - { url = "https://files.pythonhosted.org/packages/22/6e/46eb9b85c1b9761e0f42b6e6311e1cc544843ac897457062b9d5d0b21df4/librt-0.7.8-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:e9c0afebbe6ce177ae8edba0c7c4d626f2a0fc12c33bb993d163817c41a7a05c", size = 164956, upload-time = "2026-01-14T12:54:57.311Z" }, - { url = "https://files.pythonhosted.org/packages/7a/3f/aa7c7f6829fb83989feb7ba9aa11c662b34b4bd4bd5b262f2876ba3db58d/librt-0.7.8-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:631599598e2c76ded400c0a8722dec09217c89ff64dc54b060f598ed68e7d2a8", size = 174364, upload-time = "2026-01-14T12:54:59.089Z" }, - { url = "https://files.pythonhosted.org/packages/3f/2d/d57d154b40b11f2cb851c4df0d4c4456bacd9b1ccc4ecb593ddec56c1a8b/librt-0.7.8-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9c1ba843ae20db09b9d5c80475376168feb2640ce91cd9906414f23cc267a1ff", size = 188034, upload-time = "2026-01-14T12:55:00.141Z" }, - { url = "https://files.pythonhosted.org/packages/59/f9/36c4dad00925c16cd69d744b87f7001792691857d3b79187e7a673e812fb/librt-0.7.8-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b5b007bb22ea4b255d3ee39dfd06d12534de2fcc3438567d9f48cdaf67ae1ae3", size = 186295, upload-time = "2026-01-14T12:55:01.303Z" }, - { url = "https://files.pythonhosted.org/packages/23/9b/8a9889d3df5efb67695a67785028ccd58e661c3018237b73ad081691d0cb/librt-0.7.8-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:dbd79caaf77a3f590cbe32dc2447f718772d6eea59656a7dcb9311161b10fa75", size = 181470, upload-time = "2026-01-14T12:55:02.492Z" }, - { url = "https://files.pythonhosted.org/packages/43/64/54d6ef11afca01fef8af78c230726a9394759f2addfbf7afc5e3cc032a45/librt-0.7.8-cp311-cp311-musllinux_1_2_x86_64.whl", hash = 
"sha256:87808a8d1e0bd62a01cafc41f0fd6818b5a5d0ca0d8a55326a81643cdda8f873", size = 201713, upload-time = "2026-01-14T12:55:03.919Z" }, - { url = "https://files.pythonhosted.org/packages/2d/29/73e7ed2991330b28919387656f54109139b49e19cd72902f466bd44415fd/librt-0.7.8-cp311-cp311-win32.whl", hash = "sha256:31724b93baa91512bd0a376e7cf0b59d8b631ee17923b1218a65456fa9bda2e7", size = 43803, upload-time = "2026-01-14T12:55:04.996Z" }, - { url = "https://files.pythonhosted.org/packages/3f/de/66766ff48ed02b4d78deea30392ae200bcbd99ae61ba2418b49fd50a4831/librt-0.7.8-cp311-cp311-win_amd64.whl", hash = "sha256:978e8b5f13e52cf23a9e80f3286d7546baa70bc4ef35b51d97a709d0b28e537c", size = 50080, upload-time = "2026-01-14T12:55:06.489Z" }, - { url = "https://files.pythonhosted.org/packages/6f/e3/33450438ff3a8c581d4ed7f798a70b07c3206d298cf0b87d3806e72e3ed8/librt-0.7.8-cp311-cp311-win_arm64.whl", hash = "sha256:20e3946863d872f7cabf7f77c6c9d370b8b3d74333d3a32471c50d3a86c0a232", size = 43383, upload-time = "2026-01-14T12:55:07.49Z" }, - { url = "https://files.pythonhosted.org/packages/56/04/79d8fcb43cae376c7adbab7b2b9f65e48432c9eced62ac96703bcc16e09b/librt-0.7.8-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:9b6943885b2d49c48d0cff23b16be830ba46b0152d98f62de49e735c6e655a63", size = 57472, upload-time = "2026-01-14T12:55:08.528Z" }, - { url = "https://files.pythonhosted.org/packages/b4/ba/60b96e93043d3d659da91752689023a73981336446ae82078cddf706249e/librt-0.7.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:46ef1f4b9b6cc364b11eea0ecc0897314447a66029ee1e55859acb3dd8757c93", size = 58986, upload-time = "2026-01-14T12:55:09.466Z" }, - { url = "https://files.pythonhosted.org/packages/7c/26/5215e4cdcc26e7be7eee21955a7e13cbf1f6d7d7311461a6014544596fac/librt-0.7.8-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:907ad09cfab21e3c86e8f1f87858f7049d1097f77196959c033612f532b4e592", size = 168422, upload-time = "2026-01-14T12:55:10.499Z" }, - { url = 
"https://files.pythonhosted.org/packages/0f/84/e8d1bc86fa0159bfc24f3d798d92cafd3897e84c7fea7fe61b3220915d76/librt-0.7.8-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2991b6c3775383752b3ca0204842743256f3ad3deeb1d0adc227d56b78a9a850", size = 177478, upload-time = "2026-01-14T12:55:11.577Z" }, - { url = "https://files.pythonhosted.org/packages/57/11/d0268c4b94717a18aa91df1100e767b010f87b7ae444dafaa5a2d80f33a6/librt-0.7.8-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:03679b9856932b8c8f674e87aa3c55ea11c9274301f76ae8dc4d281bda55cf62", size = 192439, upload-time = "2026-01-14T12:55:12.7Z" }, - { url = "https://files.pythonhosted.org/packages/8d/56/1e8e833b95fe684f80f8894ae4d8b7d36acc9203e60478fcae599120a975/librt-0.7.8-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3968762fec1b2ad34ce57458b6de25dbb4142713e9ca6279a0d352fa4e9f452b", size = 191483, upload-time = "2026-01-14T12:55:13.838Z" }, - { url = "https://files.pythonhosted.org/packages/17/48/f11cf28a2cb6c31f282009e2208312aa84a5ee2732859f7856ee306176d5/librt-0.7.8-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:bb7a7807523a31f03061288cc4ffc065d684c39db7644c676b47d89553c0d714", size = 185376, upload-time = "2026-01-14T12:55:15.017Z" }, - { url = "https://files.pythonhosted.org/packages/b8/6a/d7c116c6da561b9155b184354a60a3d5cdbf08fc7f3678d09c95679d13d9/librt-0.7.8-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ad64a14b1e56e702e19b24aae108f18ad1bf7777f3af5fcd39f87d0c5a814449", size = 206234, upload-time = "2026-01-14T12:55:16.571Z" }, - { url = "https://files.pythonhosted.org/packages/61/de/1975200bb0285fc921c5981d9978ce6ce11ae6d797df815add94a5a848a3/librt-0.7.8-cp312-cp312-win32.whl", hash = "sha256:0241a6ed65e6666236ea78203a73d800dbed896cf12ae25d026d75dc1fcd1dac", size = 44057, upload-time = "2026-01-14T12:55:18.077Z" }, - { url = 
"https://files.pythonhosted.org/packages/8e/cd/724f2d0b3461426730d4877754b65d39f06a41ac9d0a92d5c6840f72b9ae/librt-0.7.8-cp312-cp312-win_amd64.whl", hash = "sha256:6db5faf064b5bab9675c32a873436b31e01d66ca6984c6f7f92621656033a708", size = 50293, upload-time = "2026-01-14T12:55:19.179Z" }, - { url = "https://files.pythonhosted.org/packages/bd/cf/7e899acd9ee5727ad8160fdcc9994954e79fab371c66535c60e13b968ffc/librt-0.7.8-cp312-cp312-win_arm64.whl", hash = "sha256:57175aa93f804d2c08d2edb7213e09276bd49097611aefc37e3fa38d1fb99ad0", size = 43574, upload-time = "2026-01-14T12:55:20.185Z" }, - { url = "https://files.pythonhosted.org/packages/a1/fe/b1f9de2829cf7fc7649c1dcd202cfd873837c5cc2fc9e526b0e7f716c3d2/librt-0.7.8-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4c3995abbbb60b3c129490fa985dfe6cac11d88fc3c36eeb4fb1449efbbb04fc", size = 57500, upload-time = "2026-01-14T12:55:21.219Z" }, - { url = "https://files.pythonhosted.org/packages/eb/d4/4a60fbe2e53b825f5d9a77325071d61cd8af8506255067bf0c8527530745/librt-0.7.8-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:44e0c2cbc9bebd074cf2cdbe472ca185e824be4e74b1c63a8e934cea674bebf2", size = 59019, upload-time = "2026-01-14T12:55:22.256Z" }, - { url = "https://files.pythonhosted.org/packages/6a/37/61ff80341ba5159afa524445f2d984c30e2821f31f7c73cf166dcafa5564/librt-0.7.8-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:4d2f1e492cae964b3463a03dc77a7fe8742f7855d7258c7643f0ee32b6651dd3", size = 169015, upload-time = "2026-01-14T12:55:23.24Z" }, - { url = "https://files.pythonhosted.org/packages/1c/86/13d4f2d6a93f181ebf2fc953868826653ede494559da8268023fe567fca3/librt-0.7.8-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:451e7ffcef8f785831fdb791bd69211f47e95dc4c6ddff68e589058806f044c6", size = 178161, upload-time = "2026-01-14T12:55:24.826Z" }, - { url = 
"https://files.pythonhosted.org/packages/88/26/e24ef01305954fc4d771f1f09f3dd682f9eb610e1bec188ffb719374d26e/librt-0.7.8-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3469e1af9f1380e093ae06bedcbdd11e407ac0b303a56bbe9afb1d6824d4982d", size = 193015, upload-time = "2026-01-14T12:55:26.04Z" }, - { url = "https://files.pythonhosted.org/packages/88/a0/92b6bd060e720d7a31ed474d046a69bd55334ec05e9c446d228c4b806ae3/librt-0.7.8-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f11b300027ce19a34f6d24ebb0a25fd0e24a9d53353225a5c1e6cadbf2916b2e", size = 192038, upload-time = "2026-01-14T12:55:27.208Z" }, - { url = "https://files.pythonhosted.org/packages/06/bb/6f4c650253704279c3a214dad188101d1b5ea23be0606628bc6739456624/librt-0.7.8-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:4adc73614f0d3c97874f02f2c7fd2a27854e7e24ad532ea6b965459c5b757eca", size = 186006, upload-time = "2026-01-14T12:55:28.594Z" }, - { url = "https://files.pythonhosted.org/packages/dc/00/1c409618248d43240cadf45f3efb866837fa77e9a12a71481912135eb481/librt-0.7.8-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:60c299e555f87e4c01b2eca085dfccda1dde87f5a604bb45c2906b8305819a93", size = 206888, upload-time = "2026-01-14T12:55:30.214Z" }, - { url = "https://files.pythonhosted.org/packages/d9/83/b2cfe8e76ff5c1c77f8a53da3d5de62d04b5ebf7cf913e37f8bca43b5d07/librt-0.7.8-cp313-cp313-win32.whl", hash = "sha256:b09c52ed43a461994716082ee7d87618096851319bf695d57ec123f2ab708951", size = 44126, upload-time = "2026-01-14T12:55:31.44Z" }, - { url = "https://files.pythonhosted.org/packages/a9/0b/c59d45de56a51bd2d3a401fc63449c0ac163e4ef7f523ea8b0c0dee86ec5/librt-0.7.8-cp313-cp313-win_amd64.whl", hash = "sha256:f8f4a901a3fa28969d6e4519deceab56c55a09d691ea7b12ca830e2fa3461e34", size = 50262, upload-time = "2026-01-14T12:55:33.01Z" }, - { url = 
"https://files.pythonhosted.org/packages/fc/b9/973455cec0a1ec592395250c474164c4a58ebf3e0651ee920fef1a2623f1/librt-0.7.8-cp313-cp313-win_arm64.whl", hash = "sha256:43d4e71b50763fcdcf64725ac680d8cfa1706c928b844794a7aa0fa9ac8e5f09", size = 43600, upload-time = "2026-01-14T12:55:34.054Z" }, - { url = "https://files.pythonhosted.org/packages/1a/73/fa8814c6ce2d49c3827829cadaa1589b0bf4391660bd4510899393a23ebc/librt-0.7.8-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:be927c3c94c74b05128089a955fba86501c3b544d1d300282cc1b4bd370cb418", size = 57049, upload-time = "2026-01-14T12:55:35.056Z" }, - { url = "https://files.pythonhosted.org/packages/53/fe/f6c70956da23ea235fd2e3cc16f4f0b4ebdfd72252b02d1164dd58b4e6c3/librt-0.7.8-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:7b0803e9008c62a7ef79058233db7ff6f37a9933b8f2573c05b07ddafa226611", size = 58689, upload-time = "2026-01-14T12:55:36.078Z" }, - { url = "https://files.pythonhosted.org/packages/1f/4d/7a2481444ac5fba63050d9abe823e6bc16896f575bfc9c1e5068d516cdce/librt-0.7.8-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:79feb4d00b2a4e0e05c9c56df707934f41fcb5fe53fd9efb7549068d0495b758", size = 166808, upload-time = "2026-01-14T12:55:37.595Z" }, - { url = "https://files.pythonhosted.org/packages/ac/3c/10901d9e18639f8953f57c8986796cfbf4c1c514844a41c9197cf87cb707/librt-0.7.8-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b9122094e3f24aa759c38f46bd8863433820654927370250f460ae75488b66ea", size = 175614, upload-time = "2026-01-14T12:55:38.756Z" }, - { url = "https://files.pythonhosted.org/packages/db/01/5cbdde0951a5090a80e5ba44e6357d375048123c572a23eecfb9326993a7/librt-0.7.8-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7e03bea66af33c95ce3addf87a9bf1fcad8d33e757bc479957ddbc0e4f7207ac", size = 189955, upload-time = "2026-01-14T12:55:39.939Z" }, - { url = 
"https://files.pythonhosted.org/packages/6a/b4/e80528d2f4b7eaf1d437fcbd6fc6ba4cbeb3e2a0cb9ed5a79f47c7318706/librt-0.7.8-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:f1ade7f31675db00b514b98f9ab9a7698c7282dad4be7492589109471852d398", size = 189370, upload-time = "2026-01-14T12:55:41.057Z" }, - { url = "https://files.pythonhosted.org/packages/c1/ab/938368f8ce31a9787ecd4becb1e795954782e4312095daf8fd22420227c8/librt-0.7.8-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:a14229ac62adcf1b90a15992f1ab9c69ae8b99ffb23cb64a90878a6e8a2f5b81", size = 183224, upload-time = "2026-01-14T12:55:42.328Z" }, - { url = "https://files.pythonhosted.org/packages/3c/10/559c310e7a6e4014ac44867d359ef8238465fb499e7eb31b6bfe3e3f86f5/librt-0.7.8-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:5bcaaf624fd24e6a0cb14beac37677f90793a96864c67c064a91458611446e83", size = 203541, upload-time = "2026-01-14T12:55:43.501Z" }, - { url = "https://files.pythonhosted.org/packages/f8/db/a0db7acdb6290c215f343835c6efda5b491bb05c3ddc675af558f50fdba3/librt-0.7.8-cp314-cp314-win32.whl", hash = "sha256:7aa7d5457b6c542ecaed79cec4ad98534373c9757383973e638ccced0f11f46d", size = 40657, upload-time = "2026-01-14T12:55:44.668Z" }, - { url = "https://files.pythonhosted.org/packages/72/e0/4f9bdc2a98a798511e81edcd6b54fe82767a715e05d1921115ac70717f6f/librt-0.7.8-cp314-cp314-win_amd64.whl", hash = "sha256:3d1322800771bee4a91f3b4bd4e49abc7d35e65166821086e5afd1e6c0d9be44", size = 46835, upload-time = "2026-01-14T12:55:45.655Z" }, - { url = "https://files.pythonhosted.org/packages/f9/3d/59c6402e3dec2719655a41ad027a7371f8e2334aa794ed11533ad5f34969/librt-0.7.8-cp314-cp314-win_arm64.whl", hash = "sha256:5363427bc6a8c3b1719f8f3845ea53553d301382928a86e8fab7984426949bce", size = 39885, upload-time = "2026-01-14T12:55:47.138Z" }, - { url = "https://files.pythonhosted.org/packages/4e/9c/2481d80950b83085fb14ba3c595db56330d21bbc7d88a19f20165f3538db/librt-0.7.8-cp314-cp314t-macosx_10_13_x86_64.whl", hash = 
"sha256:ca916919793a77e4a98d4a1701e345d337ce53be4a16620f063191f7322ac80f", size = 59161, upload-time = "2026-01-14T12:55:48.45Z" }, - { url = "https://files.pythonhosted.org/packages/96/79/108df2cfc4e672336765d54e3ff887294c1cc36ea4335c73588875775527/librt-0.7.8-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:54feb7b4f2f6706bb82325e836a01be805770443e2400f706e824e91f6441dde", size = 61008, upload-time = "2026-01-14T12:55:49.527Z" }, - { url = "https://files.pythonhosted.org/packages/46/f2/30179898f9994a5637459d6e169b6abdc982012c0a4b2d4c26f50c06f911/librt-0.7.8-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:39a4c76fee41007070f872b648cc2f711f9abf9a13d0c7162478043377b52c8e", size = 187199, upload-time = "2026-01-14T12:55:50.587Z" }, - { url = "https://files.pythonhosted.org/packages/b4/da/f7563db55cebdc884f518ba3791ad033becc25ff68eb70902b1747dc0d70/librt-0.7.8-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ac9c8a458245c7de80bc1b9765b177055efff5803f08e548dd4bb9ab9a8d789b", size = 198317, upload-time = "2026-01-14T12:55:51.991Z" }, - { url = "https://files.pythonhosted.org/packages/b3/6c/4289acf076ad371471fa86718c30ae353e690d3de6167f7db36f429272f1/librt-0.7.8-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:95b67aa7eff150f075fda09d11f6bfb26edffd300f6ab1666759547581e8f666", size = 210334, upload-time = "2026-01-14T12:55:53.682Z" }, - { url = "https://files.pythonhosted.org/packages/4a/7f/377521ac25b78ac0a5ff44127a0360ee6d5ddd3ce7327949876a30533daa/librt-0.7.8-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:535929b6eff670c593c34ff435d5440c3096f20fa72d63444608a5aef64dd581", size = 211031, upload-time = "2026-01-14T12:55:54.827Z" }, - { url = "https://files.pythonhosted.org/packages/c5/b1/e1e96c3e20b23d00cf90f4aad48f0deb4cdfec2f0ed8380d0d85acf98bbf/librt-0.7.8-cp314-cp314t-musllinux_1_2_i686.whl", hash = 
"sha256:63937bd0f4d1cb56653dc7ae900d6c52c41f0015e25aaf9902481ee79943b33a", size = 204581, upload-time = "2026-01-14T12:55:56.811Z" }, - { url = "https://files.pythonhosted.org/packages/43/71/0f5d010e92ed9747e14bef35e91b6580533510f1e36a8a09eb79ee70b2f0/librt-0.7.8-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:cf243da9e42d914036fd362ac3fa77d80a41cadcd11ad789b1b5eec4daaf67ca", size = 224731, upload-time = "2026-01-14T12:55:58.175Z" }, - { url = "https://files.pythonhosted.org/packages/22/f0/07fb6ab5c39a4ca9af3e37554f9d42f25c464829254d72e4ebbd81da351c/librt-0.7.8-cp314-cp314t-win32.whl", hash = "sha256:171ca3a0a06c643bd0a2f62a8944e1902c94aa8e5da4db1ea9a8daf872685365", size = 41173, upload-time = "2026-01-14T12:55:59.315Z" }, - { url = "https://files.pythonhosted.org/packages/24/d4/7e4be20993dc6a782639625bd2f97f3c66125c7aa80c82426956811cfccf/librt-0.7.8-cp314-cp314t-win_amd64.whl", hash = "sha256:445b7304145e24c60288a2f172b5ce2ca35c0f81605f5299f3fa567e189d2e32", size = 47668, upload-time = "2026-01-14T12:56:00.261Z" }, - { url = "https://files.pythonhosted.org/packages/fc/85/69f92b2a7b3c0f88ffe107c86b952b397004b5b8ea5a81da3d9c04c04422/librt-0.7.8-cp314-cp314t-win_arm64.whl", hash = "sha256:8766ece9de08527deabcd7cb1b4f1a967a385d26e33e536d6d8913db6ef74f06", size = 40550, upload-time = "2026-01-14T12:56:01.542Z" }, +version = "0.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8a/3f/4ca7dd7819bf8ff303aca39c3c60e5320e46e766ab7f7dd627d3b9c11bdf/librt-0.8.0.tar.gz", hash = "sha256:cb74cdcbc0103fc988e04e5c58b0b31e8e5dd2babb9182b6f9490488eb36324b", size = 177306, upload-time = "2026-02-12T14:53:54.743Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/e9/018cfd60629e0404e6917943789800aa2231defbea540a17b90cc4547b97/librt-0.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:db63cf3586a24241e89ca1ce0b56baaec9d371a328bd186c529b27c914c9a1ef", size = 65690, upload-time = 
"2026-02-12T14:51:57.761Z" }, + { url = "https://files.pythonhosted.org/packages/b5/80/8d39980860e4d1c9497ee50e5cd7c4766d8cfd90d105578eae418e8ffcbc/librt-0.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ba9d9e60651615bc614be5e21a82cdb7b1769a029369cf4b4d861e4f19686fb6", size = 68373, upload-time = "2026-02-12T14:51:59.013Z" }, + { url = "https://files.pythonhosted.org/packages/2d/76/6e6f7a443af63977e421bd542551fec4072d9eaba02e671b05b238fe73bc/librt-0.8.0-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:cb4b3ad543084ed79f186741470b251b9d269cd8b03556f15a8d1a99a64b7de5", size = 197091, upload-time = "2026-02-12T14:52:00.642Z" }, + { url = "https://files.pythonhosted.org/packages/14/40/fa064181c231334c9f4cb69eb338132d39510c8928e84beba34b861d0a71/librt-0.8.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3d2720335020219197380ccfa5c895f079ac364b4c429e96952cd6509934d8eb", size = 207350, upload-time = "2026-02-12T14:52:02.32Z" }, + { url = "https://files.pythonhosted.org/packages/50/49/e7f8438dd226305e3e5955d495114ad01448e6a6ffc0303289b4153b5fc5/librt-0.8.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9726305d3e53419d27fc8cdfcd3f9571f0ceae22fa6b5ea1b3662c2e538f833e", size = 219962, upload-time = "2026-02-12T14:52:03.884Z" }, + { url = "https://files.pythonhosted.org/packages/1f/2c/74086fc5d52e77107a3cc80a9a3209be6ad1c9b6bc99969d8d9bbf9fdfe4/librt-0.8.0-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:cc3d107f603b5ee7a79b6aa6f166551b99b32fb4a5303c4dfcb4222fc6a0335e", size = 212939, upload-time = "2026-02-12T14:52:05.537Z" }, + { url = "https://files.pythonhosted.org/packages/c8/ae/d6917c0ebec9bc2e0293903d6a5ccc7cdb64c228e529e96520b277318f25/librt-0.8.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:41064a0c07b4cc7a81355ccc305cb097d6027002209ffca51306e65ee8293630", size = 221393, 
upload-time = "2026-02-12T14:52:07.164Z" }, + { url = "https://files.pythonhosted.org/packages/04/97/15df8270f524ce09ad5c19cbbe0e8f95067582507149a6c90594e7795370/librt-0.8.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:c6e4c10761ddbc0d67d2f6e2753daf99908db85d8b901729bf2bf5eaa60e0567", size = 216721, upload-time = "2026-02-12T14:52:08.857Z" }, + { url = "https://files.pythonhosted.org/packages/c4/52/17cbcf9b7a1bae5016d9d3561bc7169b32c3bd216c47d934d3f270602c0c/librt-0.8.0-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:ba581acad5ac8f33e2ff1746e8a57e001b47c6721873121bf8bbcf7ba8bd3aa4", size = 214790, upload-time = "2026-02-12T14:52:10.033Z" }, + { url = "https://files.pythonhosted.org/packages/2a/2d/010a236e8dc4d717dd545c46fd036dcced2c7ede71ef85cf55325809ff92/librt-0.8.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:bdab762e2c0b48bab76f1a08acb3f4c77afd2123bedac59446aeaaeed3d086cf", size = 237384, upload-time = "2026-02-12T14:52:11.244Z" }, + { url = "https://files.pythonhosted.org/packages/38/14/f1c0eff3df8760dee761029efb72991c554d9f3282f1048e8c3d0eb60997/librt-0.8.0-cp310-cp310-win32.whl", hash = "sha256:6a3146c63220d814c4a2c7d6a1eacc8d5c14aed0ff85115c1dfea868080cd18f", size = 54289, upload-time = "2026-02-12T14:52:12.798Z" }, + { url = "https://files.pythonhosted.org/packages/2f/0b/2684d473e64890882729f91866ed97ccc0a751a0afc3b4bf1a7b57094dbb/librt-0.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:bbebd2bba5c6ae02907df49150e55870fdd7440d727b6192c46b6f754723dde9", size = 61347, upload-time = "2026-02-12T14:52:13.793Z" }, + { url = "https://files.pythonhosted.org/packages/51/e9/42af181c89b65abfd557c1b017cba5b82098eef7bf26d1649d82ce93ccc7/librt-0.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0ce33a9778e294507f3a0e3468eccb6a698b5166df7db85661543eca1cfc5369", size = 65314, upload-time = "2026-02-12T14:52:14.778Z" }, + { url = 
"https://files.pythonhosted.org/packages/9d/4a/15a847fca119dc0334a4b8012b1e15fdc5fc19d505b71e227eaf1bcdba09/librt-0.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8070aa3368559de81061ef752770d03ca1f5fc9467d4d512d405bd0483bfffe6", size = 68015, upload-time = "2026-02-12T14:52:15.797Z" }, + { url = "https://files.pythonhosted.org/packages/e1/87/ffc8dbd6ab68dd91b736c88529411a6729649d2b74b887f91f3aaff8d992/librt-0.8.0-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:20f73d4fecba969efc15cdefd030e382502d56bb6f1fc66b580cce582836c9fa", size = 194508, upload-time = "2026-02-12T14:52:16.835Z" }, + { url = "https://files.pythonhosted.org/packages/89/92/a7355cea28d6c48ff6ff5083ac4a2a866fb9b07b786aa70d1f1116680cd5/librt-0.8.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a512c88900bdb1d448882f5623a0b1ad27ba81a9bd75dacfe17080b72272ca1f", size = 205630, upload-time = "2026-02-12T14:52:18.58Z" }, + { url = "https://files.pythonhosted.org/packages/ac/5e/54509038d7ac527828db95b8ba1c8f5d2649bc32fd8f39b1718ec9957dce/librt-0.8.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:015e2dde6e096d27c10238bf9f6492ba6c65822dfb69d2bf74c41a8e88b7ddef", size = 218289, upload-time = "2026-02-12T14:52:20.134Z" }, + { url = "https://files.pythonhosted.org/packages/6d/17/0ee0d13685cefee6d6f2d47bb643ddad3c62387e2882139794e6a5f1288a/librt-0.8.0-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:1c25a131013eadd3c600686a0c0333eb2896483cbc7f65baa6a7ee761017aef9", size = 211508, upload-time = "2026-02-12T14:52:21.413Z" }, + { url = "https://files.pythonhosted.org/packages/4b/a8/1714ef6e9325582e3727de3be27e4c1b2f428ea411d09f1396374180f130/librt-0.8.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:21b14464bee0b604d80a638cf1ee3148d84ca4cc163dcdcecb46060c1b3605e4", size = 219129, upload-time = "2026-02-12T14:52:22.61Z" }, + { 
url = "https://files.pythonhosted.org/packages/89/d3/2d9fe353edff91cdc0ece179348054a6fa61f3de992c44b9477cb973509b/librt-0.8.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:05a3dd3f116747f7e1a2b475ccdc6fb637fd4987126d109e03013a79d40bf9e6", size = 213126, upload-time = "2026-02-12T14:52:23.819Z" }, + { url = "https://files.pythonhosted.org/packages/ad/8e/9f5c60444880f6ad50e3ff7475e5529e787797e7f3ad5432241633733b92/librt-0.8.0-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:fa37f99bff354ff191c6bcdffbc9d7cdd4fc37faccfc9be0ef3a4fd5613977da", size = 212279, upload-time = "2026-02-12T14:52:25.034Z" }, + { url = "https://files.pythonhosted.org/packages/fe/eb/d4a2cfa647da3022ae977f50d7eda1d91f70d7d1883cf958a4b6ef689eab/librt-0.8.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1566dbb9d1eb0987264c9b9460d212e809ba908d2f4a3999383a84d765f2f3f1", size = 234654, upload-time = "2026-02-12T14:52:26.204Z" }, + { url = "https://files.pythonhosted.org/packages/6a/31/26b978861c7983b036a3aea08bdbb2ec32bbaab1ad1d57c5e022be59afc1/librt-0.8.0-cp311-cp311-win32.whl", hash = "sha256:70defb797c4d5402166787a6b3c66dfb3fa7f93d118c0509ffafa35a392f4258", size = 54603, upload-time = "2026-02-12T14:52:27.342Z" }, + { url = "https://files.pythonhosted.org/packages/d0/78/f194ed7c48dacf875677e749c5d0d1d69a9daa7c994314a39466237fb1be/librt-0.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:db953b675079884ffda33d1dca7189fb961b6d372153750beb81880384300817", size = 61730, upload-time = "2026-02-12T14:52:28.31Z" }, + { url = "https://files.pythonhosted.org/packages/97/ee/ad71095478d02137b6f49469dc808c595cfe89b50985f6b39c5345f0faab/librt-0.8.0-cp311-cp311-win_arm64.whl", hash = "sha256:75d1a8cab20b2043f03f7aab730551e9e440adc034d776f15f6f8d582b0a5ad4", size = 52274, upload-time = "2026-02-12T14:52:29.345Z" }, + { url = "https://files.pythonhosted.org/packages/fb/53/f3bc0c4921adb0d4a5afa0656f2c0fbe20e18e3e0295e12985b9a5dc3f55/librt-0.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = 
"sha256:17269dd2745dbe8e42475acb28e419ad92dfa38214224b1b01020b8cac70b645", size = 66511, upload-time = "2026-02-12T14:52:30.34Z" }, + { url = "https://files.pythonhosted.org/packages/89/4b/4c96357432007c25a1b5e363045373a6c39481e49f6ba05234bb59a839c1/librt-0.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f4617cef654fca552f00ce5ffdf4f4b68770f18950e4246ce94629b789b92467", size = 68628, upload-time = "2026-02-12T14:52:31.491Z" }, + { url = "https://files.pythonhosted.org/packages/47/16/52d75374d1012e8fc709216b5eaa25f471370e2a2331b8be00f18670a6c7/librt-0.8.0-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:5cb11061a736a9db45e3c1293cfcb1e3caf205912dfa085734ba750f2197ff9a", size = 198941, upload-time = "2026-02-12T14:52:32.489Z" }, + { url = "https://files.pythonhosted.org/packages/fc/11/d5dd89e5a2228567b1228d8602d896736247424484db086eea6b8010bcba/librt-0.8.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b4bb00bd71b448f16749909b08a0ff16f58b079e2261c2e1000f2bbb2a4f0a45", size = 210009, upload-time = "2026-02-12T14:52:33.634Z" }, + { url = "https://files.pythonhosted.org/packages/49/d8/fc1a92a77c3020ee08ce2dc48aed4b42ab7c30fb43ce488d388673b0f164/librt-0.8.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:95a719a049f0eefaf1952673223cf00d442952273cbd20cf2ed7ec423a0ef58d", size = 224461, upload-time = "2026-02-12T14:52:34.868Z" }, + { url = "https://files.pythonhosted.org/packages/7f/98/eb923e8b028cece924c246104aa800cf72e02d023a8ad4ca87135b05a2fe/librt-0.8.0-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:bd32add59b58fba3439d48d6f36ac695830388e3da3e92e4fc26d2d02670d19c", size = 217538, upload-time = "2026-02-12T14:52:36.078Z" }, + { url = "https://files.pythonhosted.org/packages/fd/67/24e80ab170674a1d8ee9f9a83081dca4635519dbd0473b8321deecddb5be/librt-0.8.0-cp312-cp312-musllinux_1_2_aarch64.whl", 
hash = "sha256:4f764b2424cb04524ff7a486b9c391e93f93dc1bd8305b2136d25e582e99aa2f", size = 225110, upload-time = "2026-02-12T14:52:37.301Z" }, + { url = "https://files.pythonhosted.org/packages/d8/c7/6fbdcbd1a6e5243c7989c21d68ab967c153b391351174b4729e359d9977f/librt-0.8.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:f04ca50e847abc486fa8f4107250566441e693779a5374ba211e96e238f298b9", size = 217758, upload-time = "2026-02-12T14:52:38.89Z" }, + { url = "https://files.pythonhosted.org/packages/4b/bd/4d6b36669db086e3d747434430073e14def032dd58ad97959bf7e2d06c67/librt-0.8.0-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:9ab3a3475a55b89b87ffd7e6665838e8458e0b596c22e0177e0f961434ec474a", size = 218384, upload-time = "2026-02-12T14:52:40.637Z" }, + { url = "https://files.pythonhosted.org/packages/50/2d/afe966beb0a8f179b132f3e95c8dd90738a23e9ebdba10f89a3f192f9366/librt-0.8.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3e36a8da17134ffc29373775d88c04832f9ecfab1880470661813e6c7991ef79", size = 241187, upload-time = "2026-02-12T14:52:43.55Z" }, + { url = "https://files.pythonhosted.org/packages/02/d0/6172ea4af2b538462785ab1a68e52d5c99cfb9866a7caf00fdf388299734/librt-0.8.0-cp312-cp312-win32.whl", hash = "sha256:4eb5e06ebcc668677ed6389164f52f13f71737fc8be471101fa8b4ce77baeb0c", size = 54914, upload-time = "2026-02-12T14:52:44.676Z" }, + { url = "https://files.pythonhosted.org/packages/d4/cb/ceb6ed6175612a4337ad49fb01ef594712b934b4bc88ce8a63554832eb44/librt-0.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:0a33335eb59921e77c9acc05d0e654e4e32e45b014a4d61517897c11591094f8", size = 62020, upload-time = "2026-02-12T14:52:45.676Z" }, + { url = "https://files.pythonhosted.org/packages/f1/7e/61701acbc67da74ce06ddc7ba9483e81c70f44236b2d00f6a4bfee1aacbf/librt-0.8.0-cp312-cp312-win_arm64.whl", hash = "sha256:24a01c13a2a9bdad20997a4443ebe6e329df063d1978bbe2ebbf637878a46d1e", size = 52443, upload-time = "2026-02-12T14:52:47.218Z" }, + { url = 
"https://files.pythonhosted.org/packages/6d/32/3edb0bcb4113a9c8bdcd1750663a54565d255027657a5df9d90f13ee07fa/librt-0.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7f820210e21e3a8bf8fde2ae3c3d10106d4de9ead28cbfdf6d0f0f41f5b12fa1", size = 66522, upload-time = "2026-02-12T14:52:48.219Z" }, + { url = "https://files.pythonhosted.org/packages/30/ab/e8c3d05e281f5d405ebdcc5bc8ab36df23e1a4b40ac9da8c3eb9928b72b9/librt-0.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4831c44b8919e75ca0dfb52052897c1ef59fdae19d3589893fbd068f1e41afbf", size = 68658, upload-time = "2026-02-12T14:52:50.351Z" }, + { url = "https://files.pythonhosted.org/packages/7c/d3/74a206c47b7748bbc8c43942de3ed67de4c231156e148b4f9250869593df/librt-0.8.0-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:88c6e75540f1f10f5e0fc5e87b4b6c290f0e90d1db8c6734f670840494764af8", size = 199287, upload-time = "2026-02-12T14:52:51.938Z" }, + { url = "https://files.pythonhosted.org/packages/fa/29/ef98a9131cf12cb95771d24e4c411fda96c89dc78b09c2de4704877ebee4/librt-0.8.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9646178cd794704d722306c2c920c221abbf080fede3ba539d5afdec16c46dad", size = 210293, upload-time = "2026-02-12T14:52:53.128Z" }, + { url = "https://files.pythonhosted.org/packages/5b/3e/89b4968cb08c53d4c2d8b02517081dfe4b9e07a959ec143d333d76899f6c/librt-0.8.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6e1af31a710e17891d9adf0dbd9a5fcd94901a3922a96499abdbf7ce658f4e01", size = 224801, upload-time = "2026-02-12T14:52:54.367Z" }, + { url = "https://files.pythonhosted.org/packages/6d/28/f38526d501f9513f8b48d78e6be4a241e15dd4b000056dc8b3f06ee9ce5d/librt-0.8.0-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:507e94f4bec00b2f590fbe55f48cd518a208e2474a3b90a60aa8f29136ddbada", size = 218090, upload-time = "2026-02-12T14:52:55.758Z" }, + { 
url = "https://files.pythonhosted.org/packages/02/ec/64e29887c5009c24dc9c397116c680caffc50286f62bd99c39e3875a2854/librt-0.8.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f1178e0de0c271231a660fbef9be6acdfa1d596803464706862bef6644cc1cae", size = 225483, upload-time = "2026-02-12T14:52:57.375Z" }, + { url = "https://files.pythonhosted.org/packages/ee/16/7850bdbc9f1a32d3feff2708d90c56fc0490b13f1012e438532781aa598c/librt-0.8.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:71fc517efc14f75c2f74b1f0a5d5eb4a8e06aa135c34d18eaf3522f4a53cd62d", size = 218226, upload-time = "2026-02-12T14:52:58.534Z" }, + { url = "https://files.pythonhosted.org/packages/1c/4a/166bffc992d65ddefa7c47052010a87c059b44a458ebaf8f5eba384b0533/librt-0.8.0-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:0583aef7e9a720dd40f26a2ad5a1bf2ccbb90059dac2b32ac516df232c701db3", size = 218755, upload-time = "2026-02-12T14:52:59.701Z" }, + { url = "https://files.pythonhosted.org/packages/da/5d/9aeee038bcc72a9cfaaee934463fe9280a73c5440d36bd3175069d2cb97b/librt-0.8.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5d0f76fc73480d42285c609c0ea74d79856c160fa828ff9aceab574ea4ecfd7b", size = 241617, upload-time = "2026-02-12T14:53:00.966Z" }, + { url = "https://files.pythonhosted.org/packages/64/ff/2bec6b0296b9d0402aa6ec8540aa19ebcb875d669c37800cb43d10d9c3a3/librt-0.8.0-cp313-cp313-win32.whl", hash = "sha256:e79dbc8f57de360f0ed987dc7de7be814b4803ef0e8fc6d3ff86e16798c99935", size = 54966, upload-time = "2026-02-12T14:53:02.042Z" }, + { url = "https://files.pythonhosted.org/packages/08/8d/bf44633b0182996b2c7ea69a03a5c529683fa1f6b8e45c03fe874ff40d56/librt-0.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:25b3e667cbfc9000c4740b282df599ebd91dbdcc1aa6785050e4c1d6be5329ab", size = 62000, upload-time = "2026-02-12T14:53:03.822Z" }, + { url = "https://files.pythonhosted.org/packages/5c/fd/c6472b8e0eac0925001f75e366cf5500bcb975357a65ef1f6b5749389d3a/librt-0.8.0-cp313-cp313-win_arm64.whl", hash = 
"sha256:e9a3a38eb4134ad33122a6d575e6324831f930a771d951a15ce232e0237412c2", size = 52496, upload-time = "2026-02-12T14:53:04.889Z" }, + { url = "https://files.pythonhosted.org/packages/e0/13/79ebfe30cd273d7c0ce37a5f14dc489c5fb8b722a008983db2cfd57270bb/librt-0.8.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:421765e8c6b18e64d21c8ead315708a56fc24f44075059702e421d164575fdda", size = 66078, upload-time = "2026-02-12T14:53:06.085Z" }, + { url = "https://files.pythonhosted.org/packages/4b/8f/d11eca40b62a8d5e759239a80636386ef88adecb10d1a050b38cc0da9f9e/librt-0.8.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:48f84830a8f8ad7918afd743fd7c4eb558728bceab7b0e38fd5a5cf78206a556", size = 68309, upload-time = "2026-02-12T14:53:07.121Z" }, + { url = "https://files.pythonhosted.org/packages/9c/b4/f12ee70a3596db40ff3c88ec9eaa4e323f3b92f77505b4d900746706ec6a/librt-0.8.0-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:9f09d4884f882baa39a7e36bbf3eae124c4ca2a223efb91e567381d1c55c6b06", size = 196804, upload-time = "2026-02-12T14:53:08.164Z" }, + { url = "https://files.pythonhosted.org/packages/8b/7e/70dbbdc0271fd626abe1671ad117bcd61a9a88cdc6a10ccfbfc703db1873/librt-0.8.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:693697133c3b32aa9b27f040e3691be210e9ac4d905061859a9ed519b1d5a376", size = 206915, upload-time = "2026-02-12T14:53:09.333Z" }, + { url = "https://files.pythonhosted.org/packages/79/13/6b9e05a635d4327608d06b3c1702166e3b3e78315846373446cf90d7b0bf/librt-0.8.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c5512aae4648152abaf4d48b59890503fcbe86e85abc12fb9b096fe948bdd816", size = 221200, upload-time = "2026-02-12T14:53:10.68Z" }, + { url = "https://files.pythonhosted.org/packages/35/6c/e19a3ac53e9414de43a73d7507d2d766cd22d8ca763d29a4e072d628db42/librt-0.8.0-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash 
= "sha256:995d24caa6bbb34bcdd4a41df98ac6d1af637cfa8975cb0790e47d6623e70e3e", size = 214640, upload-time = "2026-02-12T14:53:12.342Z" }, + { url = "https://files.pythonhosted.org/packages/30/f0/23a78464788619e8c70f090cfd099cce4973eed142c4dccb99fc322283fd/librt-0.8.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b9aef96d7593584e31ef6ac1eb9775355b0099fee7651fae3a15bc8657b67b52", size = 221980, upload-time = "2026-02-12T14:53:13.603Z" }, + { url = "https://files.pythonhosted.org/packages/03/32/38e21420c5d7aa8a8bd2c7a7d5252ab174a5a8aaec8b5551968979b747bf/librt-0.8.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:4f6e975377fbc4c9567cb33ea9ab826031b6c7ec0515bfae66a4fb110d40d6da", size = 215146, upload-time = "2026-02-12T14:53:14.8Z" }, + { url = "https://files.pythonhosted.org/packages/bb/00/bd9ecf38b1824c25240b3ad982fb62c80f0a969e6679091ba2b3afb2b510/librt-0.8.0-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:daae5e955764be8fd70a93e9e5133c75297f8bce1e802e1d3683b98f77e1c5ab", size = 215203, upload-time = "2026-02-12T14:53:16.087Z" }, + { url = "https://files.pythonhosted.org/packages/b9/60/7559bcc5279d37810b98d4a52616febd7b8eef04391714fd6bdf629598b1/librt-0.8.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:7bd68cebf3131bb920d5984f75fe302d758db33264e44b45ad139385662d7bc3", size = 237937, upload-time = "2026-02-12T14:53:17.236Z" }, + { url = "https://files.pythonhosted.org/packages/41/cc/be3e7da88f1abbe2642672af1dc00a0bccece11ca60241b1883f3018d8d5/librt-0.8.0-cp314-cp314-win32.whl", hash = "sha256:1e6811cac1dcb27ca4c74e0ca4a5917a8e06db0d8408d30daee3a41724bfde7a", size = 50685, upload-time = "2026-02-12T14:53:18.888Z" }, + { url = "https://files.pythonhosted.org/packages/38/27/e381d0df182a8f61ef1f6025d8b138b3318cc9d18ad4d5f47c3bf7492523/librt-0.8.0-cp314-cp314-win_amd64.whl", hash = "sha256:178707cda89d910c3b28bf5aa5f69d3d4734e0f6ae102f753ad79edef83a83c7", size = 57872, upload-time = "2026-02-12T14:53:19.942Z" }, + { url = 
"https://files.pythonhosted.org/packages/c5/0c/ca9dfdf00554a44dea7d555001248269a4bab569e1590a91391feb863fa4/librt-0.8.0-cp314-cp314-win_arm64.whl", hash = "sha256:3e8b77b5f54d0937b26512774916041756c9eb3e66f1031971e626eea49d0bf4", size = 48056, upload-time = "2026-02-12T14:53:21.473Z" }, + { url = "https://files.pythonhosted.org/packages/f2/ed/6cc9c4ad24f90c8e782193c7b4a857408fd49540800613d1356c63567d7b/librt-0.8.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:789911e8fa40a2e82f41120c936b1965f3213c67f5a483fc5a41f5839a05dcbb", size = 68307, upload-time = "2026-02-12T14:53:22.498Z" }, + { url = "https://files.pythonhosted.org/packages/84/d8/0e94292c6b3e00b6eeea39dd44d5703d1ec29b6dafce7eea19dc8f1aedbd/librt-0.8.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:2b37437e7e4ef5e15a297b36ba9e577f73e29564131d86dd75875705e97402b5", size = 70999, upload-time = "2026-02-12T14:53:23.603Z" }, + { url = "https://files.pythonhosted.org/packages/0e/f4/6be1afcbdeedbdbbf54a7c9d73ad43e1bf36897cebf3978308cd64922e02/librt-0.8.0-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:671a6152edf3b924d98a5ed5e6982ec9cb30894085482acadce0975f031d4c5c", size = 220782, upload-time = "2026-02-12T14:53:25.133Z" }, + { url = "https://files.pythonhosted.org/packages/f0/8d/f306e8caa93cfaf5c6c9e0d940908d75dc6af4fd856baa5535c922ee02b1/librt-0.8.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8992ca186a1678107b0af3d0c9303d8c7305981b9914989b9788319ed4d89546", size = 235420, upload-time = "2026-02-12T14:53:27.047Z" }, + { url = "https://files.pythonhosted.org/packages/d6/f2/65d86bd462e9c351326564ca805e8457442149f348496e25ccd94583ffa2/librt-0.8.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:001e5330093d887b8b9165823eca6c5c4db183fe4edea4fdc0680bbac5f46944", size = 246452, upload-time = "2026-02-12T14:53:28.341Z" }, + { url = 
"https://files.pythonhosted.org/packages/03/94/39c88b503b4cb3fcbdeb3caa29672b6b44ebee8dcc8a54d49839ac280f3f/librt-0.8.0-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:d920789eca7ef71df7f31fd547ec0d3002e04d77f30ba6881e08a630e7b2c30e", size = 238891, upload-time = "2026-02-12T14:53:29.625Z" }, + { url = "https://files.pythonhosted.org/packages/e3/c6/6c0d68190893d01b71b9569b07a1c811e280c0065a791249921c83dc0290/librt-0.8.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:82fb4602d1b3e303a58bfe6165992b5a78d823ec646445356c332cd5f5bbaa61", size = 250249, upload-time = "2026-02-12T14:53:30.93Z" }, + { url = "https://files.pythonhosted.org/packages/52/7a/f715ed9e039035d0ea637579c3c0155ab3709a7046bc408c0fb05d337121/librt-0.8.0-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:4d3e38797eb482485b486898f89415a6ab163bc291476bd95712e42cf4383c05", size = 240642, upload-time = "2026-02-12T14:53:32.174Z" }, + { url = "https://files.pythonhosted.org/packages/c2/3c/609000a333debf5992efe087edc6467c1fdbdddca5b610355569bbea9589/librt-0.8.0-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:a905091a13e0884701226860836d0386b88c72ce5c2fdfba6618e14c72be9f25", size = 239621, upload-time = "2026-02-12T14:53:33.39Z" }, + { url = "https://files.pythonhosted.org/packages/b9/df/87b0673d5c395a8f34f38569c116c93142d4dc7e04af2510620772d6bd4f/librt-0.8.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:375eda7acfce1f15f5ed56cfc960669eefa1ec8732e3e9087c3c4c3f2066759c", size = 262986, upload-time = "2026-02-12T14:53:34.617Z" }, + { url = "https://files.pythonhosted.org/packages/09/7f/6bbbe9dcda649684773aaea78b87fff4d7e59550fbc2877faa83612087a3/librt-0.8.0-cp314-cp314t-win32.whl", hash = "sha256:2ccdd20d9a72c562ffb73098ac411de351b53a6fbb3390903b2d33078ef90447", size = 51328, upload-time = "2026-02-12T14:53:36.15Z" }, + { url = 
"https://files.pythonhosted.org/packages/bb/f3/e1981ab6fa9b41be0396648b5850267888a752d025313a9e929c4856208e/librt-0.8.0-cp314-cp314t-win_amd64.whl", hash = "sha256:25e82d920d4d62ad741592fcf8d0f3bda0e3fc388a184cb7d2f566c681c5f7b9", size = 58719, upload-time = "2026-02-12T14:53:37.183Z" }, + { url = "https://files.pythonhosted.org/packages/94/d1/433b3c06e78f23486fe4fdd19bc134657eb30997d2054b0dbf52bbf3382e/librt-0.8.0-cp314-cp314t-win_arm64.whl", hash = "sha256:92249938ab744a5890580d3cb2b22042f0dce71cdaa7c1369823df62bedf7cbc", size = 48753, upload-time = "2026-02-12T14:53:38.539Z" }, ] [[package]] From 39afa9665c827c895ff6b7f75b9b736d06898ada Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Thu, 12 Feb 2026 20:34:47 +0100 Subject: [PATCH 25/28] fix group chat import --- .../agent_framework_orchestrations/_group_chat.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_group_chat.py b/python/packages/orchestrations/agent_framework_orchestrations/_group_chat.py index 857b1128b4..a99e221409 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_group_chat.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_group_chat.py @@ -33,7 +33,6 @@ from agent_framework._workflows._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse from agent_framework._workflows._agent_utils import resolve_agent_id from agent_framework._workflows._checkpoint import CheckpointStorage -from agent_framework._workflows._conversation_state import decode_chat_messages, encode_chat_messages from agent_framework._workflows._executor import Executor from agent_framework._workflows._workflow import Workflow from agent_framework._workflows._workflow_builder import WorkflowBuilder @@ -547,7 +546,7 @@ async def _check_agent_terminate_and_yield( async def on_checkpoint_save(self) -> dict[str, Any]: """Capture current orchestrator state for 
checkpointing.""" state = await super().on_checkpoint_save() - state["cache"] = encode_chat_messages(self._cache) + state["cache"] = self._cache serialized_session = self._session.to_dict() state["session"] = serialized_session @@ -557,7 +556,7 @@ async def on_checkpoint_save(self) -> dict[str, Any]: async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: """Restore executor state from checkpoint.""" await super().on_checkpoint_restore(state) - self._cache = decode_chat_messages(state.get("cache", [])) + self._cache = state.get("cache", []) serialized_session = state.get("session") if serialized_session: self._session = AgentSession.from_dict(serialized_session) From 84ab8003927a50e37df57dfcdbc54ce1945f696c Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Thu, 12 Feb 2026 21:02:07 +0100 Subject: [PATCH 26/28] =?UTF-8?q?Rename=20Thread=E2=86=92Session=20through?= =?UTF-8?q?out,=20fix=20service=5Fsession=5Fid=20propagation,=20remove=20s?= =?UTF-8?q?tale=20AGUIThread?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix: Propagate conversation_id from ChatResponse back to session.service_session_id in both streaming and non-streaming paths in _agents.py - Rename AgentThreadException → AgentSessionException - Remove stale AGUIThread from ag_ui lazy-loader - Rename use_service_thread → use_service_session in ag-ui package - Rename test functions from *_thread_* to *_session_* - Rename sample files from *_thread* to *_session* - Update docstrings and comments: thread → session - Update _mcp.py kwargs filter: add 'session' alongside 'thread' - Fix ContinuationToken docstring example: thread=thread → session=session - Fix _clients.py docstring: 'Agent threads' → 'Agent sessions' --- .../ag-ui/agent_framework_ag_ui/_agent.py | 12 +++++----- .../ag-ui/agent_framework_ag_ui/_run.py | 14 ++++++------ .../ag_ui/test_agent_wrapper_comprehensive.py | 12 +++++----- .../packages/core/agent_framework/_agents.py | 13 
+++++++++++ .../packages/core/agent_framework/_clients.py | 2 +- python/packages/core/agent_framework/_mcp.py | 2 +- .../packages/core/agent_framework/_types.py | 2 +- .../core/agent_framework/ag_ui/__init__.py | 1 - .../core/agent_framework/exceptions.py | 4 ++-- .../azure/test_azure_assistants_client.py | 22 +++++++++---------- .../tests/azure/test_azure_chat_client.py | 8 +++---- .../azure/test_azure_responses_client.py | 4 ++-- python/packages/core/tests/core/test_mcp.py | 2 +- .../openai/test_openai_assistants_client.py | 22 +++++++++---------- .../tests/workflow/test_workflow_agent.py | 14 ++++++------ .../samples/01-get-started/03_multi_turn.py | 4 ++-- .../chat_client/azure_ai_chat_client.py | 2 +- .../chat_client/azure_assistants_client.py | 2 +- .../chat_client/azure_chat_client.py | 2 +- .../chat_client/azure_responses_client.py | 2 +- .../chat_client/openai_assistants_client.py | 2 +- .../chat_client/openai_chat_client.py | 2 +- .../chat_client/openai_responses_client.py | 2 +- .../context_providers/mem0/mem0_basic.py | 2 +- .../context_providers/mem0/mem0_oss.py | 2 +- .../{mem0_threads.py => mem0_sessions.py} | 2 +- .../context_providers/redis/redis_basics.py | 2 +- .../{redis_threads.py => redis_sessions.py} | 0 ...y => custom_chat_message_store_session.py} | 0 ...py => redis_chat_message_store_session.py} | 0 ...me_thread.py => suspend_resume_session.py} | 0 .../devui/azure_responses_agent/agent.py | 2 +- .../02-agents/devui/foundry_agent/agent.py | 2 +- .../samples/02-agents/devui/in_memory_mode.py | 2 +- .../devui/weather_agent_azure/agent.py | 2 +- .../02-agents/mcp/agent_as_mcp_server.py | 2 +- .../agent_and_run_level_middleware.py | 2 +- .../02-agents/middleware/chat_middleware.py | 2 +- .../middleware/class_based_middleware.py | 2 +- .../middleware/decorator_middleware.py | 2 +- .../exception_handling_with_middleware.py | 2 +- .../middleware/function_based_middleware.py | 2 +- .../middleware/middleware_termination.py | 2 +- 
.../override_result_with_middleware.py | 2 +- .../middleware/runtime_context_delegation.py | 2 +- ...ware.py => session_behavior_middleware.py} | 2 +- .../middleware/shared_state_middleware.py | 2 +- .../advanced_manual_setup_console_output.py | 2 +- .../observability/advanced_zero_code.py | 2 +- .../observability/agent_observability.py | 2 +- .../agent_with_foundry_tracing.py | 2 +- .../azure_ai_agent_observability.py | 2 +- .../configure_otel_providers_with_env_var.py | 2 +- ...onfigure_otel_providers_with_parameters.py | 2 +- .../orchestrations/handoff_simple.py | 2 +- .../providers/anthropic/anthropic_basic.py | 2 +- .../providers/azure_ai/azure_ai_basic.py | 2 +- .../azure_ai/azure_ai_provider_methods.py | 2 +- .../azure_ai/azure_ai_use_latest_version.py | 2 +- .../azure_ai_with_existing_conversation.py | 2 +- .../azure_ai_with_explicit_settings.py | 2 +- ...ith_thread.py => azure_ai_with_session.py} | 2 +- .../azure_ai_agent/azure_ai_basic.py | 2 +- .../azure_ai_provider_methods.py | 2 +- ...d.py => azure_ai_with_existing_session.py} | 2 +- .../azure_ai_with_explicit_settings.py | 2 +- .../azure_ai_with_function_tools.py | 2 +- .../azure_ai_with_multiple_tools.py | 2 +- ...ith_thread.py => azure_ai_with_session.py} | 2 +- .../azure_openai/azure_assistants_basic.py | 2 +- ...zure_assistants_with_existing_assistant.py | 2 +- ...azure_assistants_with_explicit_settings.py | 2 +- .../azure_assistants_with_function_tools.py | 2 +- ...ad.py => azure_assistants_with_session.py} | 2 +- .../azure_openai/azure_chat_client_basic.py | 2 +- ...zure_chat_client_with_explicit_settings.py | 2 +- .../azure_chat_client_with_function_tools.py | 2 +- ...d.py => azure_chat_client_with_session.py} | 2 +- .../azure_responses_client_basic.py | 2 +- ...responses_client_with_explicit_settings.py | 2 +- .../azure_responses_client_with_foundry.py | 2 +- ...re_responses_client_with_function_tools.py | 2 +- ...=> azure_responses_client_with_session.py} | 2 +- 
.../github_copilot/github_copilot_basic.py | 2 +- .../github_copilot_with_session.py | 2 +- .../providers/ollama/ollama_agent_basic.py | 2 +- .../providers/ollama/ollama_chat_client.py | 2 +- .../ollama/ollama_with_openai_chat_client.py | 2 +- .../openai/openai_assistants_basic.py | 2 +- .../openai_assistants_provider_methods.py | 2 +- ...enai_assistants_with_existing_assistant.py | 2 +- ...penai_assistants_with_explicit_settings.py | 2 +- .../openai_assistants_with_function_tools.py | 2 +- ...d.py => openai_assistants_with_session.py} | 2 +- .../openai/openai_chat_client_basic.py | 2 +- ...enai_chat_client_with_explicit_settings.py | 2 +- .../openai_chat_client_with_function_tools.py | 2 +- ....py => openai_chat_client_with_session.py} | 2 +- .../openai/openai_responses_client_basic.py | 2 +- ...responses_client_with_explicit_settings.py | 2 +- ...ai_responses_client_with_function_tools.py | 2 +- ...> openai_responses_client_with_session.py} | 2 +- .../function_invocation_configuration.py | 2 +- .../function_tool_recover_from_failures.py | 2 +- .../tools/function_tool_with_approval.py | 2 +- ...nction_tool_with_approval_and_sessions.py} | 0 .../tools/function_tool_with_kwargs.py | 2 +- .../function_tool_with_session_injection.py | 2 +- ...=> azure_ai_agents_with_shared_session.py} | 0 ...re_chat_agents_tool_calls_with_feedback.py | 2 +- .../agents/handoff_workflow_as_agent.py | 2 +- .../agents/workflow_as_agent_kwargs.py | 2 +- ...d.py => workflow_as_agent_with_session.py} | 0 .../composition/sub_workflow_kwargs.py | 2 +- .../declarative/function_tools/main.py | 2 +- .../agents_with_approval_requests.py | 2 +- .../orchestrations/handoff_simple.py | 2 +- .../state-management/workflow_kwargs.py | 2 +- .../concurrent_builder_tool_approval.py | 2 +- .../group_chat_builder_tool_approval.py | 2 +- .../sequential_builder_tool_approval.py | 2 +- .../02_multi_agent/function_app.py | 2 +- .../05-end-to-end/chatkit-integration/app.py | 2 +- 
.../m365-agent/m365_agent_demo/app.py | 2 +- .../02_assistant_agent_with_tool.py | 2 +- .../sessions/mem0/mem0_basic.py | 2 +- .../getting_started/sessions/mem0/mem0_oss.py | 2 +- .../sessions/mem0/mem0_sessions.py | 2 +- .../sessions/redis/redis_basics.py | 2 +- 129 files changed, 181 insertions(+), 169 deletions(-) rename python/samples/02-agents/context_providers/mem0/{mem0_threads.py => mem0_sessions.py} (99%) rename python/samples/02-agents/context_providers/redis/{redis_threads.py => redis_sessions.py} (100%) rename python/samples/02-agents/conversations/{custom_chat_message_store_thread.py => custom_chat_message_store_session.py} (100%) rename python/samples/02-agents/conversations/{redis_chat_message_store_thread.py => redis_chat_message_store_session.py} (100%) rename python/samples/02-agents/conversations/{suspend_resume_thread.py => suspend_resume_session.py} (100%) rename python/samples/02-agents/middleware/{thread_behavior_middleware.py => session_behavior_middleware.py} (99%) rename python/samples/02-agents/providers/azure_ai/{azure_ai_with_thread.py => azure_ai_with_session.py} (98%) rename python/samples/02-agents/providers/azure_ai_agent/{azure_ai_with_existing_thread.py => azure_ai_with_existing_session.py} (96%) rename python/samples/02-agents/providers/azure_ai_agent/{azure_ai_with_thread.py => azure_ai_with_session.py} (98%) rename python/samples/02-agents/providers/azure_openai/{azure_assistants_with_thread.py => azure_assistants_with_session.py} (98%) rename python/samples/02-agents/providers/azure_openai/{azure_chat_client_with_thread.py => azure_chat_client_with_session.py} (98%) rename python/samples/02-agents/providers/azure_openai/{azure_responses_client_with_thread.py => azure_responses_client_with_session.py} (98%) rename python/samples/02-agents/providers/openai/{openai_assistants_with_thread.py => openai_assistants_with_session.py} (98%) rename python/samples/02-agents/providers/openai/{openai_chat_client_with_thread.py => 
openai_chat_client_with_session.py} (98%) rename python/samples/02-agents/providers/openai/{openai_responses_client_with_thread.py => openai_responses_client_with_session.py} (98%) rename python/samples/02-agents/tools/{function_tool_with_approval_and_threads.py => function_tool_with_approval_and_sessions.py} (100%) rename python/samples/03-workflows/agents/{azure_ai_agents_with_shared_thread.py => azure_ai_agents_with_shared_session.py} (100%) rename python/samples/03-workflows/agents/{workflow_as_agent_with_thread.py => workflow_as_agent_with_session.py} (100%) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py index b7e632dbea..765cde7f73 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py @@ -18,7 +18,7 @@ def __init__( self, state_schema: Any | None = None, predict_state_config: dict[str, dict[str, str]] | None = None, - use_service_thread: bool = False, + use_service_session: bool = False, require_confirmation: bool = True, ): """Initialize agent configuration. 
@@ -26,12 +26,12 @@ def __init__( Args: state_schema: Optional state schema for state management; accepts dict or Pydantic model/class predict_state_config: Configuration for predictive state updates - use_service_thread: Whether the agent thread is service-managed + use_service_session: Whether the agent session is service-managed require_confirmation: Whether predictive updates require user confirmation before applying """ self.state_schema = self._normalize_state_schema(state_schema) self.predict_state_config = predict_state_config or {} - self.use_service_thread = use_service_thread + self.use_service_session = use_service_session self.require_confirmation = require_confirmation @staticmethod @@ -77,7 +77,7 @@ def __init__( state_schema: Any | None = None, predict_state_config: dict[str, dict[str, str]] | None = None, require_confirmation: bool = True, - use_service_thread: bool = False, + use_service_session: bool = False, ): """Initialize the AG-UI compatible agent wrapper. @@ -88,7 +88,7 @@ def __init__( state_schema: Optional state schema for state management; accepts dict or Pydantic model/class predict_state_config: Configuration for predictive state updates require_confirmation: Whether predictive updates require user confirmation before applying - use_service_thread: Whether the agent thread is service-managed + use_service_session: Whether the agent session is service-managed """ self.agent = agent self.name = name or getattr(agent, "name", "agent") @@ -97,7 +97,7 @@ def __init__( self.config = AgentConfig( state_schema=state_schema, predict_state_config=predict_state_config, - use_service_thread=use_service_thread, + use_service_session=use_service_session, require_confirmation=require_confirmation, ) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_run.py index 69eeba84ff..8ab9fa0956 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_run.py +++ 
b/python/packages/ag-ui/agent_framework_ag_ui/_run.py @@ -809,12 +809,12 @@ async def run_agent_stream( register_additional_client_tools(agent, client_tools) tools = merge_tools(server_tools, client_tools) - # Create thread (with service thread support) - if config.use_service_thread: + # Create session (with service session support) + if config.use_service_session: supplied_thread_id = input_data.get("thread_id") or input_data.get("threadId") - thread = AgentSession(service_session_id=supplied_thread_id) + session = AgentSession(service_session_id=supplied_thread_id) else: - thread = AgentSession() + session = AgentSession() # Inject metadata for AG-UI orchestration (Feature #2: Azure-safe truncation) base_metadata: dict[str, Any] = { @@ -823,16 +823,16 @@ async def run_agent_stream( } if flow.current_state: base_metadata["current_state"] = flow.current_state - thread.metadata = _build_safe_metadata(base_metadata) # type: ignore[attr-defined] + session.metadata = _build_safe_metadata(base_metadata) # type: ignore[attr-defined] # Build run kwargs (Feature #6: Azure store flag when metadata present) - run_kwargs: dict[str, Any] = {"session": thread} + run_kwargs: dict[str, Any] = {"session": session} if tools: run_kwargs["tools"] = tools # Filter out AG-UI internal metadata keys before passing to chat client # These are used internally for orchestration and should not be sent to the LLM provider client_metadata = { - k: v for k, v in (getattr(thread, "metadata", None) or {}).items() if k not in AG_UI_INTERNAL_METADATA_KEYS + k: v for k, v in (getattr(session, "metadata", None) or {}).items() if k not in AG_UI_INTERNAL_METADATA_KEYS } safe_metadata = _build_safe_metadata(client_metadata) if client_metadata else {} if safe_metadata: diff --git a/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py index 165756af39..ae978f7869 100644 --- 
a/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py @@ -597,8 +597,8 @@ async def stream_fn( assert len(tool_events) == 0 -async def test_agent_with_use_service_thread_is_false(streaming_chat_client_stub): - """Test that when use_service_thread is False, the AgentSession used to run the agent is NOT set to the service session ID.""" +async def test_agent_with_use_service_session_is_false(streaming_chat_client_stub): + """Test that when use_service_session is False, the AgentSession used to run the agent is NOT set to the service session ID.""" from agent_framework.ag_ui import AgentFrameworkAgent request_service_session_id: str | None = None @@ -611,7 +611,7 @@ async def stream_fn( ) agent = Agent(client=streaming_chat_client_stub(stream_fn)) - wrapper = AgentFrameworkAgent(agent=agent, use_service_thread=False) + wrapper = AgentFrameworkAgent(agent=agent, use_service_session=False) input_data = {"messages": [{"role": "user", "content": "Hi"}], "thread_id": "conv_123456"} @@ -621,8 +621,8 @@ async def stream_fn( assert request_service_session_id is None # type: ignore[attr-defined] (service_session_id should be set) -async def test_agent_with_use_service_thread_is_true(streaming_chat_client_stub): - """Test that when use_service_thread is True, the AgentSession used to run the agent is set to the service session ID.""" +async def test_agent_with_use_service_session_is_true(streaming_chat_client_stub): + """Test that when use_service_session is True, the AgentSession used to run the agent is set to the service session ID.""" from agent_framework.ag_ui import AgentFrameworkAgent request_service_session_id: str | None = None @@ -638,7 +638,7 @@ async def stream_fn( ) agent = Agent(client=streaming_chat_client_stub(stream_fn)) - wrapper = AgentFrameworkAgent(agent=agent, use_service_thread=True) + wrapper = AgentFrameworkAgent(agent=agent, use_service_session=True) 
input_data = {"messages": [{"role": "user", "content": "Hi"}], "thread_id": "conv_123456"} diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index b6dbca8099..4ed1b525a0 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -916,6 +916,15 @@ async def _post_hook(response: AgentResponse) -> None: if message.author_name is None: message.author_name = ctx["agent_name"] + # Propagate conversation_id back to session from streaming updates + sess = ctx["session"] + if sess and not sess.service_session_id and response.raw_representation: + raw_items = response.raw_representation if isinstance(response.raw_representation, list) else [] + for item in raw_items: + if hasattr(item, "conversation_id") and item.conversation_id: + sess.service_session_id = item.conversation_id + break + # Run after_run providers (reverse order) session_context = ctx["session_context"] session_context._response = AgentResponse( # type: ignore[assignment] @@ -1091,6 +1100,10 @@ async def _finalize_response( if message.author_name is None: message.author_name = agent_name + # Propagate conversation_id back to session (e.g. thread ID from Assistants API) + if session and response.conversation_id and not session.service_session_id: + session.service_session_id = response.conversation_id + # Set the response on the context for after_run providers session_context._response = AgentResponse( # type: ignore[assignment] messages=response.messages, diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index 9abdbb5697..57daed6286 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -266,7 +266,7 @@ async def _stream(): """Whether this client stores conversation history server-side by default. 
Clients that use server-side storage (e.g., OpenAI Responses API with ``store=True`` - as default, Azure AI Agent threads) should override this to ``True``. + as default, Azure AI Agent sessions) should override this to ``True``. When ``True``, the agent skips auto-injecting ``InMemoryHistoryProvider`` unless the user explicitly sets ``store=False``. """ diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index 61c8620bf7..b52d2b252b 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -872,7 +872,7 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str: k: v for k, v in kwargs.items() if k - not in {"chat_options", "tools", "tool_choice", "thread", "conversation_id", "options", "response_format"} + not in {"chat_options", "tools", "tool_choice", "session", "thread", "conversation_id", "options", "response_format"} } parser = self.parse_tool_results or _parse_tool_result_from_mcp diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index 79c41c1023..71a635f1cc 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -1791,7 +1791,7 @@ class ContinuationToken(TypedDict): # Restore and resume token = json.loads(token_json) response = await agent.run( - thread=thread, + session=session, options={"continuation_token": token}, ) """ diff --git a/python/packages/core/agent_framework/ag_ui/__init__.py b/python/packages/core/agent_framework/ag_ui/__init__.py index 13d1e442cd..b469bb8a60 100644 --- a/python/packages/core/agent_framework/ag_ui/__init__.py +++ b/python/packages/core/agent_framework/ag_ui/__init__.py @@ -8,7 +8,6 @@ _IMPORTS = [ "__version__", "AgentFrameworkAgent", - "AGUIThread", "add_agent_framework_fastapi_endpoint", "AGUIChatClient", "AGUIEventConverter", diff --git a/python/packages/core/agent_framework/exceptions.py 
b/python/packages/core/agent_framework/exceptions.py index 45296ad74c..21e50e571e 100644 --- a/python/packages/core/agent_framework/exceptions.py +++ b/python/packages/core/agent_framework/exceptions.py @@ -49,8 +49,8 @@ class AgentInitializationError(AgentException): pass -class AgentThreadException(AgentException): - """An error occurred while managing the agent thread.""" +class AgentSessionException(AgentException): + """An error occurred while managing the agent session.""" pass diff --git a/python/packages/core/tests/azure/test_azure_assistants_client.py b/python/packages/core/tests/azure/test_azure_assistants_client.py index bb93ca71d4..eff19b27e6 100644 --- a/python/packages/core/tests/azure/test_azure_assistants_client.py +++ b/python/packages/core/tests/azure/test_azure_assistants_client.py @@ -433,8 +433,8 @@ async def test_azure_assistants_agent_basic_run_streaming(): @pytest.mark.flaky @skip_if_azure_integration_tests_disabled -async def test_azure_assistants_agent_thread_persistence(): - """Test Agent thread persistence across runs with AzureOpenAIAssistantsClient.""" +async def test_azure_assistants_agent_session_persistence(): + """Test Agent session persistence across runs with AzureOpenAIAssistantsClient.""" async with Agent( client=AzureOpenAIAssistantsClient(credential=AzureCliCredential()), instructions="You are a helpful assistant with good memory.", @@ -462,10 +462,10 @@ async def test_azure_assistants_agent_thread_persistence(): @pytest.mark.flaky @skip_if_azure_integration_tests_disabled -async def test_azure_assistants_agent_existing_thread_id(): - """Test Agent with existing thread ID to continue conversations across agent instances.""" - # First, create a conversation and capture the thread ID - existing_thread_id = None +async def test_azure_assistants_agent_existing_session_id(): + """Test Agent with existing session ID to continue conversations across agent instances.""" + # First, create a conversation and capture the session ID + 
existing_session_id = None async with Agent( client=AzureOpenAIAssistantsClient(credential=AzureCliCredential()), @@ -482,18 +482,18 @@ async def test_azure_assistants_agent_existing_thread_id(): assert any(word in response1.text.lower() for word in ["weather", "paris"]) # The session ID is set after the first response - existing_thread_id = session.service_session_id - assert existing_thread_id is not None + existing_session_id = session.service_session_id + assert existing_session_id is not None - # Now continue with the same thread ID in a new agent instance + # Now continue with the same session ID in a new agent instance async with Agent( - client=AzureOpenAIAssistantsClient(thread_id=existing_thread_id, credential=AzureCliCredential()), + client=AzureOpenAIAssistantsClient(thread_id=existing_session_id, credential=AzureCliCredential()), instructions="You are a helpful weather agent.", tools=[get_weather], ) as agent: # Create a session with the existing ID - session = AgentSession(service_session_id=existing_thread_id) + session = AgentSession(service_session_id=existing_session_id) # Ask about the previous conversation response2 = await agent.run("What was the last city I asked about?", session=session) diff --git a/python/packages/core/tests/azure/test_azure_chat_client.py b/python/packages/core/tests/azure/test_azure_chat_client.py index 2c4cf331b3..b9a279d478 100644 --- a/python/packages/core/tests/azure/test_azure_chat_client.py +++ b/python/packages/core/tests/azure/test_azure_chat_client.py @@ -800,8 +800,8 @@ async def test_azure_openai_chat_client_agent_basic_run_streaming(): @pytest.mark.flaky @skip_if_azure_integration_tests_disabled -async def test_azure_openai_chat_client_agent_thread_persistence(): - """Test Azure OpenAI chat client agent thread persistence across runs with AzureOpenAIChatClient.""" +async def test_azure_openai_chat_client_agent_session_persistence(): + """Test Azure OpenAI chat client agent session persistence across runs with 
AzureOpenAIChatClient.""" async with Agent( client=AzureOpenAIChatClient(credential=AzureCliCredential()), instructions="You are a helpful assistant with good memory.", @@ -825,8 +825,8 @@ async def test_azure_openai_chat_client_agent_thread_persistence(): @pytest.mark.flaky @skip_if_azure_integration_tests_disabled -async def test_azure_openai_chat_client_agent_existing_thread(): - """Test Azure OpenAI chat client agent with existing thread to continue conversations across agent instances.""" +async def test_azure_openai_chat_client_agent_existing_session(): + """Test Azure OpenAI chat client agent with existing session to continue conversations across agent instances.""" # First conversation - capture the session preserved_session = None diff --git a/python/packages/core/tests/azure/test_azure_responses_client.py b/python/packages/core/tests/azure/test_azure_responses_client.py index d4705b3aab..ef8a7df479 100644 --- a/python/packages/core/tests/azure/test_azure_responses_client.py +++ b/python/packages/core/tests/azure/test_azure_responses_client.py @@ -537,8 +537,8 @@ async def test_integration_client_agent_hosted_code_interpreter_tool(): @pytest.mark.flaky @skip_if_azure_integration_tests_disabled -async def test_integration_client_agent_existing_thread(): - """Test Azure Responses Client agent with existing thread to continue conversations across agent instances.""" +async def test_integration_client_agent_existing_session(): + """Test Azure Responses Client agent with existing session to continue conversations across agent instances.""" # First conversation - capture the session preserved_session = None diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 06376137d6..38cb243412 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -2588,7 +2588,7 @@ class MockResponseFormat(BaseModel): chat_options={"some": "option"}, # Should be filtered 
tools=[Mock()], # Should be filtered tool_choice="auto", # Should be filtered - thread=Mock(), # Should be filtered + session=Mock(), # Should be filtered conversation_id="conv-123", # Should be filtered options={"metadata": "value"}, # Should be filtered ) diff --git a/python/packages/core/tests/openai/test_openai_assistants_client.py b/python/packages/core/tests/openai/test_openai_assistants_client.py index 3e4ac27131..80e2020d02 100644 --- a/python/packages/core/tests/openai/test_openai_assistants_client.py +++ b/python/packages/core/tests/openai/test_openai_assistants_client.py @@ -1264,8 +1264,8 @@ async def test_openai_assistants_agent_basic_run_streaming(): @pytest.mark.flaky @skip_if_openai_integration_tests_disabled -async def test_openai_assistants_agent_thread_persistence(): - """Test Agent thread persistence across runs with OpenAIAssistantsClient.""" +async def test_openai_assistants_agent_session_persistence(): + """Test Agent session persistence across runs with OpenAIAssistantsClient.""" async with Agent( client=OpenAIAssistantsClient(model_id=INTEGRATION_TEST_MODEL), instructions="You are a helpful assistant with good memory.", @@ -1293,10 +1293,10 @@ async def test_openai_assistants_agent_thread_persistence(): @pytest.mark.flaky @skip_if_openai_integration_tests_disabled -async def test_openai_assistants_agent_existing_thread_id(): - """Test Agent with existing thread ID to continue conversations across agent instances.""" - # First, create a conversation and capture the thread ID - existing_thread_id = None +async def test_openai_assistants_agent_existing_session_id(): + """Test Agent with existing session ID to continue conversations across agent instances.""" + # First, create a conversation and capture the session ID + existing_session_id = None async with Agent( client=OpenAIAssistantsClient(model_id=INTEGRATION_TEST_MODEL), @@ -1313,18 +1313,18 @@ async def test_openai_assistants_agent_existing_thread_id(): assert any(word in 
response1.text.lower() for word in ["weather", "paris"]) # The session ID is set after the first response - existing_thread_id = session.service_session_id - assert existing_thread_id is not None + existing_session_id = session.service_session_id + assert existing_session_id is not None - # Now continue with the same thread ID in a new agent instance + # Now continue with the same session ID in a new agent instance async with Agent( - client=OpenAIAssistantsClient(thread_id=existing_thread_id), + client=OpenAIAssistantsClient(thread_id=existing_session_id), instructions="You are a helpful weather agent.", tools=[get_weather], ) as agent: # Create a session with the existing ID - session = AgentSession(service_session_id=existing_thread_id) + session = AgentSession(service_session_id=existing_session_id) # Ask about the previous conversation response2 = await agent.run("What was the last city I asked about?", session=session) diff --git a/python/packages/core/tests/workflow/test_workflow_agent.py b/python/packages/core/tests/workflow/test_workflow_agent.py index 8c1066aa42..5adf82dd57 100644 --- a/python/packages/core/tests/workflow/test_workflow_agent.py +++ b/python/packages/core/tests/workflow/test_workflow_agent.py @@ -510,12 +510,12 @@ async def list_yielding_executor(messages: list[Message], ctx: WorkflowContext[N texts = [message.text for message in result.messages] assert texts == ["first message", "second message", "third fourth"] - async def test_thread_conversation_history_included_in_workflow_run(self) -> None: + async def test_session_conversation_history_included_in_workflow_run(self) -> None: """Test that messages provided to agent.run() are passed through to the workflow.""" # Create an executor that captures all received messages capturing_executor = ConversationHistoryCapturingExecutor(id="capturing", streaming=False) workflow = WorkflowBuilder(start_executor=capturing_executor).build() - agent = WorkflowAgent(workflow=workflow, name="Thread 
History Test Agent") + agent = WorkflowAgent(workflow=workflow, name="Session History Test Agent") # Create a session session = AgentSession() @@ -528,12 +528,12 @@ async def test_thread_conversation_history_included_in_workflow_run(self) -> Non assert len(capturing_executor.received_messages) == 1 assert capturing_executor.received_messages[0].text == "New user question" - async def test_thread_conversation_history_included_in_workflow_stream(self) -> None: + async def test_session_conversation_history_included_in_workflow_stream(self) -> None: """Test that messages provided to agent.run() are passed through when streaming WorkflowAgent.""" # Create an executor that captures all received messages capturing_executor = ConversationHistoryCapturingExecutor(id="capturing_stream") workflow = WorkflowBuilder(start_executor=capturing_executor).build() - agent = WorkflowAgent(workflow=workflow, name="Thread Stream Test Agent") + agent = WorkflowAgent(workflow=workflow, name="Session Stream Test Agent") # Create a session session = AgentSession() @@ -546,11 +546,11 @@ async def test_thread_conversation_history_included_in_workflow_stream(self) -> assert len(capturing_executor.received_messages) == 1 assert capturing_executor.received_messages[0].text == "How are you?" 
- async def test_empty_thread_works_correctly(self) -> None: + async def test_empty_session_works_correctly(self) -> None: """Test that an empty session (no message store) works correctly.""" - capturing_executor = ConversationHistoryCapturingExecutor(id="empty_thread_test") + capturing_executor = ConversationHistoryCapturingExecutor(id="empty_session_test") workflow = WorkflowBuilder(start_executor=capturing_executor).build() - agent = WorkflowAgent(workflow=workflow, name="Empty Thread Test Agent") + agent = WorkflowAgent(workflow=workflow, name="Empty Session Test Agent") # Create an empty session session = AgentSession() diff --git a/python/samples/01-get-started/03_multi_turn.py b/python/samples/01-get-started/03_multi_turn.py index 4f7d7dacbe..266764c395 100644 --- a/python/samples/01-get-started/03_multi_turn.py +++ b/python/samples/01-get-started/03_multi_turn.py @@ -7,10 +7,10 @@ from azure.identity import AzureCliCredential """ -Multi-Turn Conversations — Use AgentThread to maintain context +Multi-Turn Conversations — Use AgentSession to maintain context This sample shows how to keep conversation history across multiple calls -by reusing the same thread object. +by reusing the same session object. Environment variables: AZURE_AI_PROJECT_ENDPOINT — Your Azure AI Foundry project endpoint diff --git a/python/samples/02-agents/chat_client/azure_ai_chat_client.py b/python/samples/02-agents/chat_client/azure_ai_chat_client.py index 7d07473ba1..236d93f1a7 100644 --- a/python/samples/02-agents/chat_client/azure_ai_chat_client.py +++ b/python/samples/02-agents/chat_client/azure_ai_chat_client.py @@ -17,7 +17,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/chat_client/azure_assistants_client.py b/python/samples/02-agents/chat_client/azure_assistants_client.py index d507b2800e..66034e8eee 100644 --- a/python/samples/02-agents/chat_client/azure_assistants_client.py +++ b/python/samples/02-agents/chat_client/azure_assistants_client.py @@ -17,7 +17,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/chat_client/azure_chat_client.py b/python/samples/02-agents/chat_client/azure_chat_client.py index f1244b2e77..675df29774 100644 --- a/python/samples/02-agents/chat_client/azure_chat_client.py +++ b/python/samples/02-agents/chat_client/azure_chat_client.py @@ -17,7 +17,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/chat_client/azure_responses_client.py b/python/samples/02-agents/chat_client/azure_responses_client.py index 6f41e43e3f..7ab4212a3a 100644 --- a/python/samples/02-agents/chat_client/azure_responses_client.py +++ b/python/samples/02-agents/chat_client/azure_responses_client.py @@ -17,7 +17,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, "The location to get the weather for."], diff --git a/python/samples/02-agents/chat_client/openai_assistants_client.py b/python/samples/02-agents/chat_client/openai_assistants_client.py index ad80e69ae4..7783743950 100644 --- a/python/samples/02-agents/chat_client/openai_assistants_client.py +++ b/python/samples/02-agents/chat_client/openai_assistants_client.py @@ -17,7 +17,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/chat_client/openai_chat_client.py b/python/samples/02-agents/chat_client/openai_chat_client.py index efe06fdd8f..e784c17ae2 100644 --- a/python/samples/02-agents/chat_client/openai_chat_client.py +++ b/python/samples/02-agents/chat_client/openai_chat_client.py @@ -17,7 +17,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/chat_client/openai_responses_client.py b/python/samples/02-agents/chat_client/openai_responses_client.py index f2283c7fa4..ba589e1c2f 100644 --- a/python/samples/02-agents/chat_client/openai_responses_client.py +++ b/python/samples/02-agents/chat_client/openai_responses_client.py @@ -17,7 +17,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/context_providers/mem0/mem0_basic.py b/python/samples/02-agents/context_providers/mem0/mem0_basic.py index f7a3a7f91f..b4e99e0a9f 100644 --- a/python/samples/02-agents/context_providers/mem0/mem0_basic.py +++ b/python/samples/02-agents/context_providers/mem0/mem0_basic.py @@ -9,7 +9,7 @@ from azure.identity.aio import AzureCliCredential -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def retrieve_company_report(company_code: str, detailed: bool) -> str: if company_code != "CNTS": diff --git a/python/samples/02-agents/context_providers/mem0/mem0_oss.py b/python/samples/02-agents/context_providers/mem0/mem0_oss.py index 2178bbfe58..1b03ac5fc1 100644 --- a/python/samples/02-agents/context_providers/mem0/mem0_oss.py +++ b/python/samples/02-agents/context_providers/mem0/mem0_oss.py @@ -10,7 +10,7 @@ from mem0 import AsyncMemory -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def retrieve_company_report(company_code: str, detailed: bool) -> str: if company_code != "CNTS": diff --git a/python/samples/02-agents/context_providers/mem0/mem0_threads.py b/python/samples/02-agents/context_providers/mem0/mem0_sessions.py similarity index 99% rename from python/samples/02-agents/context_providers/mem0/mem0_threads.py rename to python/samples/02-agents/context_providers/mem0/mem0_sessions.py index dd657b4e1d..cc5548e979 100644 --- a/python/samples/02-agents/context_providers/mem0/mem0_threads.py +++ b/python/samples/02-agents/context_providers/mem0/mem0_sessions.py @@ -9,7 +9,7 @@ from azure.identity.aio import AzureCliCredential -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_user_preferences(user_id: str) -> str: """Mock function to get user preferences.""" diff --git a/python/samples/02-agents/context_providers/redis/redis_basics.py b/python/samples/02-agents/context_providers/redis/redis_basics.py index 81238eb171..5f78d65320 100644 --- a/python/samples/02-agents/context_providers/redis/redis_basics.py +++ b/python/samples/02-agents/context_providers/redis/redis_basics.py @@ -37,7 +37,7 @@ from redisvl.utils.vectorize import OpenAITextVectorizer -# NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def search_flights(origin_airport_code: str, destination_airport_code: str, detailed: bool = False) -> str: """Simulated flight-search tool to demonstrate tool memory. diff --git a/python/samples/02-agents/context_providers/redis/redis_threads.py b/python/samples/02-agents/context_providers/redis/redis_sessions.py similarity index 100% rename from python/samples/02-agents/context_providers/redis/redis_threads.py rename to python/samples/02-agents/context_providers/redis/redis_sessions.py diff --git a/python/samples/02-agents/conversations/custom_chat_message_store_thread.py b/python/samples/02-agents/conversations/custom_chat_message_store_session.py similarity index 100% rename from python/samples/02-agents/conversations/custom_chat_message_store_thread.py rename to python/samples/02-agents/conversations/custom_chat_message_store_session.py diff --git a/python/samples/02-agents/conversations/redis_chat_message_store_thread.py b/python/samples/02-agents/conversations/redis_chat_message_store_session.py similarity index 100% rename from python/samples/02-agents/conversations/redis_chat_message_store_thread.py rename to python/samples/02-agents/conversations/redis_chat_message_store_session.py diff --git a/python/samples/02-agents/conversations/suspend_resume_thread.py b/python/samples/02-agents/conversations/suspend_resume_session.py similarity index 100% rename from python/samples/02-agents/conversations/suspend_resume_thread.py rename to python/samples/02-agents/conversations/suspend_resume_session.py diff --git 
a/python/samples/02-agents/devui/azure_responses_agent/agent.py b/python/samples/02-agents/devui/azure_responses_agent/agent.py index 2293367027..bb7de3d54d 100644 --- a/python/samples/02-agents/devui/azure_responses_agent/agent.py +++ b/python/samples/02-agents/devui/azure_responses_agent/agent.py @@ -50,7 +50,7 @@ def analyze_content( return f"Analyzing content for: {query}" -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def summarize_document( length: Annotated[str, "Desired summary length: 'brief', 'medium', or 'detailed'"] = "medium", diff --git a/python/samples/02-agents/devui/foundry_agent/agent.py b/python/samples/02-agents/devui/foundry_agent/agent.py index e4f8c17079..59599bce54 100644 --- a/python/samples/02-agents/devui/foundry_agent/agent.py +++ b/python/samples/02-agents/devui/foundry_agent/agent.py @@ -14,7 +14,7 @@ from pydantic import Field -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/devui/in_memory_mode.py b/python/samples/02-agents/devui/in_memory_mode.py index 983164a484..78e0ae18ed 100644 --- a/python/samples/02-agents/devui/in_memory_mode.py +++ b/python/samples/02-agents/devui/in_memory_mode.py @@ -16,7 +16,7 @@ from typing_extensions import Never -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") # Tool functions for the agent @tool(approval_mode="never_require") diff --git a/python/samples/02-agents/devui/weather_agent_azure/agent.py b/python/samples/02-agents/devui/weather_agent_azure/agent.py index a4c531b7ce..38e7e11ce3 100644 --- a/python/samples/02-agents/devui/weather_agent_azure/agent.py +++ b/python/samples/02-agents/devui/weather_agent_azure/agent.py @@ -101,7 +101,7 @@ async def atlantis_location_filter_middleware( await call_next() -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_weather( location: Annotated[str, "The location to get the weather for."], diff --git a/python/samples/02-agents/mcp/agent_as_mcp_server.py b/python/samples/02-agents/mcp/agent_as_mcp_server.py index ad94a0cec6..0a04b7c567 100644 --- a/python/samples/02-agents/mcp/agent_as_mcp_server.py +++ b/python/samples/02-agents/mcp/agent_as_mcp_server.py @@ -32,7 +32,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_specials() -> Annotated[str, "Returns the specials from the menu."]: return """ diff --git a/python/samples/02-agents/middleware/agent_and_run_level_middleware.py b/python/samples/02-agents/middleware/agent_and_run_level_middleware.py index 9ab0eb3692..cf1d97a586 100644 --- a/python/samples/02-agents/middleware/agent_and_run_level_middleware.py +++ b/python/samples/02-agents/middleware/agent_and_run_level_middleware.py @@ -54,7 +54,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/middleware/chat_middleware.py b/python/samples/02-agents/middleware/chat_middleware.py index 3370d56901..f139f11a9e 100644 --- a/python/samples/02-agents/middleware/chat_middleware.py +++ b/python/samples/02-agents/middleware/chat_middleware.py @@ -37,7 +37,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/middleware/class_based_middleware.py b/python/samples/02-agents/middleware/class_based_middleware.py index bad4e08a26..7bdb02cc69 100644 --- a/python/samples/02-agents/middleware/class_based_middleware.py +++ b/python/samples/02-agents/middleware/class_based_middleware.py @@ -34,7 +34,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/middleware/decorator_middleware.py b/python/samples/02-agents/middleware/decorator_middleware.py index 9d90d6205a..5b22b80cc0 100644 --- a/python/samples/02-agents/middleware/decorator_middleware.py +++ b/python/samples/02-agents/middleware/decorator_middleware.py @@ -42,7 +42,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_current_time() -> str: """Get the current time.""" diff --git a/python/samples/02-agents/middleware/exception_handling_with_middleware.py b/python/samples/02-agents/middleware/exception_handling_with_middleware.py index d2042424ca..d8626a095e 100644 --- a/python/samples/02-agents/middleware/exception_handling_with_middleware.py +++ b/python/samples/02-agents/middleware/exception_handling_with_middleware.py @@ -24,7 +24,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def unstable_data_service( query: Annotated[str, Field(description="The data query to execute.")], diff --git a/python/samples/02-agents/middleware/function_based_middleware.py b/python/samples/02-agents/middleware/function_based_middleware.py index 0c839775e3..ad0679219a 100644 --- a/python/samples/02-agents/middleware/function_based_middleware.py +++ b/python/samples/02-agents/middleware/function_based_middleware.py @@ -31,7 +31,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/middleware/middleware_termination.py b/python/samples/02-agents/middleware/middleware_termination.py index 8d99283782..47be212dda 100644 --- a/python/samples/02-agents/middleware/middleware_termination.py +++ b/python/samples/02-agents/middleware/middleware_termination.py @@ -30,7 +30,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/middleware/override_result_with_middleware.py b/python/samples/02-agents/middleware/override_result_with_middleware.py index 0d02d7dbb6..efaae28a9b 100644 --- a/python/samples/02-agents/middleware/override_result_with_middleware.py +++ b/python/samples/02-agents/middleware/override_result_with_middleware.py @@ -39,7 +39,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/middleware/runtime_context_delegation.py b/python/samples/02-agents/middleware/runtime_context_delegation.py index 409202db0f..a27e945a8a 100644 --- a/python/samples/02-agents/middleware/runtime_context_delegation.py +++ b/python/samples/02-agents/middleware/runtime_context_delegation.py @@ -81,7 +81,7 @@ async def inject_context_middleware( runtime_context = SessionContextContainer() -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") async def send_email( to: Annotated[str, Field(description="Recipient email address")], diff --git a/python/samples/02-agents/middleware/thread_behavior_middleware.py b/python/samples/02-agents/middleware/session_behavior_middleware.py similarity index 99% rename from python/samples/02-agents/middleware/thread_behavior_middleware.py rename to python/samples/02-agents/middleware/session_behavior_middleware.py index d20393b456..02f50c98b8 100644 --- a/python/samples/02-agents/middleware/thread_behavior_middleware.py +++ b/python/samples/02-agents/middleware/session_behavior_middleware.py @@ -31,7 +31,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/middleware/shared_state_middleware.py b/python/samples/02-agents/middleware/shared_state_middleware.py index d4953e782e..3fe80f47a4 100644 --- a/python/samples/02-agents/middleware/shared_state_middleware.py +++ b/python/samples/02-agents/middleware/shared_state_middleware.py @@ -27,7 +27,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/observability/advanced_manual_setup_console_output.py b/python/samples/02-agents/observability/advanced_manual_setup_console_output.py index 36e15539ae..c0bfd7473e 100644 --- a/python/samples/02-agents/observability/advanced_manual_setup_console_output.py +++ b/python/samples/02-agents/observability/advanced_manual_setup_console_output.py @@ -66,7 +66,7 @@ def setup_metrics(): set_meter_provider(meter_provider) -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") async def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/observability/advanced_zero_code.py b/python/samples/02-agents/observability/advanced_zero_code.py index b4ee48bdc4..650a838da8 100644 --- a/python/samples/02-agents/observability/advanced_zero_code.py +++ b/python/samples/02-agents/observability/advanced_zero_code.py @@ -40,7 +40,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") async def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/observability/agent_observability.py b/python/samples/02-agents/observability/agent_observability.py index 0cc7700625..f7cf74c66f 100644 --- a/python/samples/02-agents/observability/agent_observability.py +++ b/python/samples/02-agents/observability/agent_observability.py @@ -17,7 +17,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") async def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/observability/agent_with_foundry_tracing.py b/python/samples/02-agents/observability/agent_with_foundry_tracing.py index 242dbd080a..345da453b7 100644 --- a/python/samples/02-agents/observability/agent_with_foundry_tracing.py +++ b/python/samples/02-agents/observability/agent_with_foundry_tracing.py @@ -41,7 +41,7 @@ logger = logging.getLogger(__name__) -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") async def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/observability/azure_ai_agent_observability.py b/python/samples/02-agents/observability/azure_ai_agent_observability.py index 9395edcc38..d3860f39af 100644 --- a/python/samples/02-agents/observability/azure_ai_agent_observability.py +++ b/python/samples/02-agents/observability/azure_ai_agent_observability.py @@ -29,7 +29,7 @@ dotenv.load_dotenv() -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") async def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/observability/configure_otel_providers_with_env_var.py b/python/samples/02-agents/observability/configure_otel_providers_with_env_var.py index 50f066fce1..2b79435df5 100644 --- a/python/samples/02-agents/observability/configure_otel_providers_with_env_var.py +++ b/python/samples/02-agents/observability/configure_otel_providers_with_env_var.py @@ -31,7 +31,7 @@ SCENARIOS = ["client", "client_stream", "tool", "all"] -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. 
+# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") async def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/observability/configure_otel_providers_with_parameters.py b/python/samples/02-agents/observability/configure_otel_providers_with_parameters.py index b75fd42325..5ec2698607 100644 --- a/python/samples/02-agents/observability/configure_otel_providers_with_parameters.py +++ b/python/samples/02-agents/observability/configure_otel_providers_with_parameters.py @@ -31,7 +31,7 @@ SCENARIOS = ["client", "client_stream", "tool", "all"] -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") async def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/orchestrations/handoff_simple.py b/python/samples/02-agents/orchestrations/handoff_simple.py index 33468f1cd9..9bfe73491e 100644 --- a/python/samples/02-agents/orchestrations/handoff_simple.py +++ b/python/samples/02-agents/orchestrations/handoff_simple.py @@ -35,7 +35,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # See: # samples/02-agents/tools/function_tool_with_approval.py -# samples/02-agents/tools/function_tool_with_approval_and_threads.py. 
+# samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def process_refund(order_number: Annotated[str, "Order number to process refund for"]) -> str: """Simulated function to process a refund for a given order number.""" diff --git a/python/samples/02-agents/providers/anthropic/anthropic_basic.py b/python/samples/02-agents/providers/anthropic/anthropic_basic.py index 2cde199d35..408e129a43 100644 --- a/python/samples/02-agents/providers/anthropic/anthropic_basic.py +++ b/python/samples/02-agents/providers/anthropic/anthropic_basic.py @@ -14,7 +14,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, "The location to get the weather for."], diff --git a/python/samples/02-agents/providers/azure_ai/azure_ai_basic.py b/python/samples/02-agents/providers/azure_ai/azure_ai_basic.py index 7a76cbb500..3661aa71a4 100644 --- a/python/samples/02-agents/providers/azure_ai/azure_ai_basic.py +++ b/python/samples/02-agents/providers/azure_ai/azure_ai_basic.py @@ -19,7 +19,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_ai/azure_ai_provider_methods.py b/python/samples/02-agents/providers/azure_ai/azure_ai_provider_methods.py index 579377921d..0e72530f7d 100644 --- a/python/samples/02-agents/providers/azure_ai/azure_ai_provider_methods.py +++ b/python/samples/02-agents/providers/azure_ai/azure_ai_provider_methods.py @@ -29,7 +29,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_ai/azure_ai_use_latest_version.py b/python/samples/02-agents/providers/azure_ai/azure_ai_use_latest_version.py index 7394fca31c..4cad79c5cb 100644 --- a/python/samples/02-agents/providers/azure_ai/azure_ai_use_latest_version.py +++ b/python/samples/02-agents/providers/azure_ai/azure_ai_use_latest_version.py @@ -20,7 +20,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_ai/azure_ai_with_existing_conversation.py b/python/samples/02-agents/providers/azure_ai/azure_ai_with_existing_conversation.py index 5c2872d2ed..92a31b2835 100644 --- a/python/samples/02-agents/providers/azure_ai/azure_ai_with_existing_conversation.py +++ b/python/samples/02-agents/providers/azure_ai/azure_ai_with_existing_conversation.py @@ -19,7 +19,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_ai/azure_ai_with_explicit_settings.py b/python/samples/02-agents/providers/azure_ai/azure_ai_with_explicit_settings.py index 3b6d3b4b03..c61d5bbb76 100644 --- a/python/samples/02-agents/providers/azure_ai/azure_ai_with_explicit_settings.py +++ b/python/samples/02-agents/providers/azure_ai/azure_ai_with_explicit_settings.py @@ -20,7 +20,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_ai/azure_ai_with_thread.py b/python/samples/02-agents/providers/azure_ai/azure_ai_with_session.py similarity index 98% rename from python/samples/02-agents/providers/azure_ai/azure_ai_with_thread.py rename to python/samples/02-agents/providers/azure_ai/azure_ai_with_session.py index f9a7fbf4fa..d5fe1ba9c3 100644 --- a/python/samples/02-agents/providers/azure_ai/azure_ai_with_thread.py +++ b/python/samples/02-agents/providers/azure_ai/azure_ai_with_session.py @@ -20,7 +20,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production # See: # samples/02-agents/tools/function_tool_with_approval.py -# samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_basic.py b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_basic.py index 5e6dae68fc..0d10337e86 100644 --- a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_basic.py +++ b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_basic.py @@ -17,7 +17,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_provider_methods.py b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_provider_methods.py index debaf1ac45..e8e19b068b 100644 --- a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_provider_methods.py +++ b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_provider_methods.py @@ -21,7 +21,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_existing_thread.py b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_existing_session.py similarity index 96% rename from python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_existing_thread.py rename to python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_existing_session.py index 7a1b15259b..66451483ac 100644 --- a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_existing_thread.py +++ b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_existing_session.py @@ -21,7 +21,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. 
+# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_explicit_settings.py b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_explicit_settings.py index e8c61b836b..ea088106d0 100644 --- a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_explicit_settings.py +++ b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_explicit_settings.py @@ -20,7 +20,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_function_tools.py b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_function_tools.py index f51a458ae5..2b252af9c5 100644 --- a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_function_tools.py +++ b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_function_tools.py @@ -20,7 +20,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_multiple_tools.py b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_multiple_tools.py index 4b3cfeadb7..4e2112f0a5 100644 --- a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_multiple_tools.py +++ b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_multiple_tools.py @@ -35,7 +35,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_time() -> str: """Get the current UTC time.""" diff --git a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_thread.py b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_session.py similarity index 98% rename from python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_thread.py rename to python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_session.py index 190c002747..3025aff851 100644 --- a/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_thread.py +++ b/python/samples/02-agents/providers/azure_ai_agent/azure_ai_with_session.py @@ -19,7 +19,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_openai/azure_assistants_basic.py b/python/samples/02-agents/providers/azure_openai/azure_assistants_basic.py index d8a7715533..71fbdcbe9d 100644 --- a/python/samples/02-agents/providers/azure_openai/azure_assistants_basic.py +++ b/python/samples/02-agents/providers/azure_openai/azure_assistants_basic.py @@ -17,7 +17,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_openai/azure_assistants_with_existing_assistant.py b/python/samples/02-agents/providers/azure_openai/azure_assistants_with_existing_assistant.py index 107f9ff7a7..195e9fc26e 100644 --- a/python/samples/02-agents/providers/azure_openai/azure_assistants_with_existing_assistant.py +++ b/python/samples/02-agents/providers/azure_openai/azure_assistants_with_existing_assistant.py @@ -21,7 +21,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_openai/azure_assistants_with_explicit_settings.py b/python/samples/02-agents/providers/azure_openai/azure_assistants_with_explicit_settings.py index edf8c79f9a..c9c4cee118 100644 --- a/python/samples/02-agents/providers/azure_openai/azure_assistants_with_explicit_settings.py +++ b/python/samples/02-agents/providers/azure_openai/azure_assistants_with_explicit_settings.py @@ -20,7 +20,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_openai/azure_assistants_with_function_tools.py b/python/samples/02-agents/providers/azure_openai/azure_assistants_with_function_tools.py index 7f75c881e0..913332e953 100644 --- a/python/samples/02-agents/providers/azure_openai/azure_assistants_with_function_tools.py +++ b/python/samples/02-agents/providers/azure_openai/azure_assistants_with_function_tools.py @@ -20,7 +20,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_openai/azure_assistants_with_thread.py b/python/samples/02-agents/providers/azure_openai/azure_assistants_with_session.py similarity index 98% rename from python/samples/02-agents/providers/azure_openai/azure_assistants_with_thread.py rename to python/samples/02-agents/providers/azure_openai/azure_assistants_with_session.py index edc9b2edbb..9c4bd6e235 100644 --- a/python/samples/02-agents/providers/azure_openai/azure_assistants_with_thread.py +++ b/python/samples/02-agents/providers/azure_openai/azure_assistants_with_session.py @@ -19,7 +19,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_openai/azure_chat_client_basic.py b/python/samples/02-agents/providers/azure_openai/azure_chat_client_basic.py index 3043a0790d..c55e8682aa 100644 --- a/python/samples/02-agents/providers/azure_openai/azure_chat_client_basic.py +++ b/python/samples/02-agents/providers/azure_openai/azure_chat_client_basic.py @@ -19,7 +19,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_openai/azure_chat_client_with_explicit_settings.py b/python/samples/02-agents/providers/azure_openai/azure_chat_client_with_explicit_settings.py index a6d6b42c55..ac0a1af782 100644 --- a/python/samples/02-agents/providers/azure_openai/azure_chat_client_with_explicit_settings.py +++ b/python/samples/02-agents/providers/azure_openai/azure_chat_client_with_explicit_settings.py @@ -20,7 +20,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_openai/azure_chat_client_with_function_tools.py b/python/samples/02-agents/providers/azure_openai/azure_chat_client_with_function_tools.py index 182a2d8bda..4b42ebecef 100644 --- a/python/samples/02-agents/providers/azure_openai/azure_chat_client_with_function_tools.py +++ b/python/samples/02-agents/providers/azure_openai/azure_chat_client_with_function_tools.py @@ -20,7 +20,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_openai/azure_chat_client_with_thread.py b/python/samples/02-agents/providers/azure_openai/azure_chat_client_with_session.py similarity index 98% rename from python/samples/02-agents/providers/azure_openai/azure_chat_client_with_thread.py rename to python/samples/02-agents/providers/azure_openai/azure_chat_client_with_session.py index 04fdea8162..c5993431fb 100644 --- a/python/samples/02-agents/providers/azure_openai/azure_chat_client_with_thread.py +++ b/python/samples/02-agents/providers/azure_openai/azure_chat_client_with_session.py @@ -19,7 +19,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_openai/azure_responses_client_basic.py b/python/samples/02-agents/providers/azure_openai/azure_responses_client_basic.py index 98353c63a7..c638426e4d 100644 --- a/python/samples/02-agents/providers/azure_openai/azure_responses_client_basic.py +++ b/python/samples/02-agents/providers/azure_openai/azure_responses_client_basic.py @@ -19,7 +19,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_explicit_settings.py b/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_explicit_settings.py index ee27fa8160..57498b9e1f 100644 --- a/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_explicit_settings.py +++ b/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_explicit_settings.py @@ -20,7 +20,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_foundry.py b/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_foundry.py index ecc53e6f3f..36b6572427 100644 --- a/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_foundry.py +++ b/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_foundry.py @@ -28,7 +28,7 @@ load_dotenv() # Load environment variables from .env file if present -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_function_tools.py b/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_function_tools.py index dacdf69c63..32fc56ed97 100644 --- a/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_function_tools.py +++ b/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_function_tools.py @@ -20,7 +20,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_thread.py b/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_session.py similarity index 98% rename from python/samples/02-agents/providers/azure_openai/azure_responses_client_with_thread.py rename to python/samples/02-agents/providers/azure_openai/azure_responses_client_with_session.py index 2de40871a4..a406b9969d 100644 --- a/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_thread.py +++ b/python/samples/02-agents/providers/azure_openai/azure_responses_client_with_session.py @@ -19,7 +19,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/github_copilot/github_copilot_basic.py b/python/samples/02-agents/providers/github_copilot/github_copilot_basic.py index 9d5aa78b5c..6faacea67c 100644 --- a/python/samples/02-agents/providers/github_copilot/github_copilot_basic.py +++ b/python/samples/02-agents/providers/github_copilot/github_copilot_basic.py @@ -22,7 +22,7 @@ from pydantic import Field -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/github_copilot/github_copilot_with_session.py b/python/samples/02-agents/providers/github_copilot/github_copilot_with_session.py index c07395ba6a..4d386bfa19 100644 --- a/python/samples/02-agents/providers/github_copilot/github_copilot_with_session.py +++ b/python/samples/02-agents/providers/github_copilot/github_copilot_with_session.py @@ -17,7 +17,7 @@ from pydantic import Field -# NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/ollama/ollama_agent_basic.py b/python/samples/02-agents/providers/ollama/ollama_agent_basic.py index 698a7f9009..92b7f47156 100644 --- a/python/samples/02-agents/providers/ollama/ollama_agent_basic.py +++ b/python/samples/02-agents/providers/ollama/ollama_agent_basic.py @@ -19,7 +19,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_time(location: str) -> str: """Get the current time.""" diff --git a/python/samples/02-agents/providers/ollama/ollama_chat_client.py b/python/samples/02-agents/providers/ollama/ollama_chat_client.py index 88f887479b..636b7e4aa2 100644 --- a/python/samples/02-agents/providers/ollama/ollama_chat_client.py +++ b/python/samples/02-agents/providers/ollama/ollama_chat_client.py @@ -19,7 +19,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_time(): """Get the current time.""" diff --git a/python/samples/02-agents/providers/ollama/ollama_with_openai_chat_client.py b/python/samples/02-agents/providers/ollama/ollama_with_openai_chat_client.py index 200549b76e..140c338192 100644 --- a/python/samples/02-agents/providers/ollama/ollama_with_openai_chat_client.py +++ b/python/samples/02-agents/providers/ollama/ollama_with_openai_chat_client.py @@ -21,7 +21,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, "The location to get the weather for."], diff --git a/python/samples/02-agents/providers/openai/openai_assistants_basic.py b/python/samples/02-agents/providers/openai/openai_assistants_basic.py index 768631ea98..573b3c260c 100644 --- a/python/samples/02-agents/providers/openai/openai_assistants_basic.py +++ b/python/samples/02-agents/providers/openai/openai_assistants_basic.py @@ -20,7 +20,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/openai/openai_assistants_provider_methods.py b/python/samples/02-agents/providers/openai/openai_assistants_provider_methods.py index 588c2f48dc..f74e064002 100644 --- a/python/samples/02-agents/providers/openai/openai_assistants_provider_methods.py +++ b/python/samples/02-agents/providers/openai/openai_assistants_provider_methods.py @@ -22,7 +22,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/openai/openai_assistants_with_existing_assistant.py b/python/samples/02-agents/providers/openai/openai_assistants_with_existing_assistant.py index b1751885ce..4e432f88b2 100644 --- a/python/samples/02-agents/providers/openai/openai_assistants_with_existing_assistant.py +++ b/python/samples/02-agents/providers/openai/openai_assistants_with_existing_assistant.py @@ -18,7 +18,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/openai/openai_assistants_with_explicit_settings.py b/python/samples/02-agents/providers/openai/openai_assistants_with_explicit_settings.py index 2ef830da78..aa7b5db6f6 100644 --- a/python/samples/02-agents/providers/openai/openai_assistants_with_explicit_settings.py +++ b/python/samples/02-agents/providers/openai/openai_assistants_with_explicit_settings.py @@ -20,7 +20,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/openai/openai_assistants_with_function_tools.py b/python/samples/02-agents/providers/openai/openai_assistants_with_function_tools.py index aa0e6dc289..2b406170a6 100644 --- a/python/samples/02-agents/providers/openai/openai_assistants_with_function_tools.py +++ b/python/samples/02-agents/providers/openai/openai_assistants_with_function_tools.py @@ -19,7 +19,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/openai/openai_assistants_with_thread.py b/python/samples/02-agents/providers/openai/openai_assistants_with_session.py similarity index 98% rename from python/samples/02-agents/providers/openai/openai_assistants_with_thread.py rename to python/samples/02-agents/providers/openai/openai_assistants_with_session.py index 155b5d6a73..39e13745a2 100644 --- a/python/samples/02-agents/providers/openai/openai_assistants_with_thread.py +++ b/python/samples/02-agents/providers/openai/openai_assistants_with_session.py @@ -20,7 +20,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/openai/openai_chat_client_basic.py b/python/samples/02-agents/providers/openai/openai_chat_client_basic.py index 167cdb6f4c..fb7bd42613 100644 --- a/python/samples/02-agents/providers/openai/openai_chat_client_basic.py +++ b/python/samples/02-agents/providers/openai/openai_chat_client_basic.py @@ -17,7 +17,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. 
+# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, "The location to get the weather for."], diff --git a/python/samples/02-agents/providers/openai/openai_chat_client_with_explicit_settings.py b/python/samples/02-agents/providers/openai/openai_chat_client_with_explicit_settings.py index a1dc84fb02..e5b85c31fa 100644 --- a/python/samples/02-agents/providers/openai/openai_chat_client_with_explicit_settings.py +++ b/python/samples/02-agents/providers/openai/openai_chat_client_with_explicit_settings.py @@ -19,7 +19,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/openai/openai_chat_client_with_function_tools.py b/python/samples/02-agents/providers/openai/openai_chat_client_with_function_tools.py index c8e5f20f7d..d66a5cf778 100644 --- a/python/samples/02-agents/providers/openai/openai_chat_client_with_function_tools.py +++ b/python/samples/02-agents/providers/openai/openai_chat_client_with_function_tools.py @@ -19,7 +19,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/openai/openai_chat_client_with_thread.py b/python/samples/02-agents/providers/openai/openai_chat_client_with_session.py similarity index 98% rename from python/samples/02-agents/providers/openai/openai_chat_client_with_thread.py rename to python/samples/02-agents/providers/openai/openai_chat_client_with_session.py index 6470b3a815..a5bbf20b63 100644 --- a/python/samples/02-agents/providers/openai/openai_chat_client_with_thread.py +++ b/python/samples/02-agents/providers/openai/openai_chat_client_with_session.py @@ -18,7 +18,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/openai/openai_responses_client_basic.py b/python/samples/02-agents/providers/openai/openai_responses_client_basic.py index 42eddc694f..fa4766e575 100644 --- a/python/samples/02-agents/providers/openai/openai_responses_client_basic.py +++ b/python/samples/02-agents/providers/openai/openai_responses_client_basic.py @@ -68,7 +68,7 @@ async def security_and_override_middleware( # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/openai/openai_responses_client_with_explicit_settings.py b/python/samples/02-agents/providers/openai/openai_responses_client_with_explicit_settings.py index 428e6dbb79..9aeba0f009 100644 --- a/python/samples/02-agents/providers/openai/openai_responses_client_with_explicit_settings.py +++ b/python/samples/02-agents/providers/openai/openai_responses_client_with_explicit_settings.py @@ -19,7 +19,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/openai/openai_responses_client_with_function_tools.py b/python/samples/02-agents/providers/openai/openai_responses_client_with_function_tools.py index faba202798..5884467650 100644 --- a/python/samples/02-agents/providers/openai/openai_responses_client_with_function_tools.py +++ b/python/samples/02-agents/providers/openai/openai_responses_client_with_function_tools.py @@ -19,7 +19,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/providers/openai/openai_responses_client_with_thread.py b/python/samples/02-agents/providers/openai/openai_responses_client_with_session.py similarity index 98% rename from python/samples/02-agents/providers/openai/openai_responses_client_with_thread.py rename to python/samples/02-agents/providers/openai/openai_responses_client_with_session.py index 4000db96cc..30866cc5da 100644 --- a/python/samples/02-agents/providers/openai/openai_responses_client_with_thread.py +++ b/python/samples/02-agents/providers/openai/openai_responses_client_with_session.py @@ -18,7 +18,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/tools/function_invocation_configuration.py b/python/samples/02-agents/tools/function_invocation_configuration.py index b116e53197..81318484e6 100644 --- a/python/samples/02-agents/tools/function_invocation_configuration.py +++ b/python/samples/02-agents/tools/function_invocation_configuration.py @@ -14,7 +14,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def add( x: Annotated[int, "First number"], diff --git a/python/samples/02-agents/tools/function_tool_recover_from_failures.py b/python/samples/02-agents/tools/function_tool_recover_from_failures.py index 9c506d1304..8f1cde5d14 100644 --- a/python/samples/02-agents/tools/function_tool_recover_from_failures.py +++ b/python/samples/02-agents/tools/function_tool_recover_from_failures.py @@ -14,7 +14,7 @@ """ -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def greet(name: Annotated[str, "Name to greet"]) -> str: """Greet someone.""" diff --git a/python/samples/02-agents/tools/function_tool_with_approval.py b/python/samples/02-agents/tools/function_tool_with_approval.py index 21cd8bfe24..e87b3da462 100644 --- a/python/samples/02-agents/tools/function_tool_with_approval.py +++ b/python/samples/02-agents/tools/function_tool_with_approval.py @@ -20,7 +20,7 @@ conditions = ["sunny", "cloudy", "raining", "snowing", "clear"] -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather(location: Annotated[str, "The city and state, e.g. San Francisco, CA"]) -> str: """Get the current weather for a given location.""" diff --git a/python/samples/02-agents/tools/function_tool_with_approval_and_threads.py b/python/samples/02-agents/tools/function_tool_with_approval_and_sessions.py similarity index 100% rename from python/samples/02-agents/tools/function_tool_with_approval_and_threads.py rename to python/samples/02-agents/tools/function_tool_with_approval_and_sessions.py diff --git a/python/samples/02-agents/tools/function_tool_with_kwargs.py b/python/samples/02-agents/tools/function_tool_with_kwargs.py index 400d9f65f1..15dd597354 100644 --- a/python/samples/02-agents/tools/function_tool_with_kwargs.py +++ b/python/samples/02-agents/tools/function_tool_with_kwargs.py @@ -20,7 +20,7 @@ # Define the function tool with **kwargs to accept injected arguments -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/02-agents/tools/function_tool_with_session_injection.py b/python/samples/02-agents/tools/function_tool_with_session_injection.py index bc89ca80ec..5e5a8322ac 100644 --- a/python/samples/02-agents/tools/function_tool_with_session_injection.py +++ b/python/samples/02-agents/tools/function_tool_with_session_injection.py @@ -16,7 +16,7 @@ # Define the function tool with **kwargs -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") async def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/03-workflows/agents/azure_ai_agents_with_shared_thread.py b/python/samples/03-workflows/agents/azure_ai_agents_with_shared_session.py similarity index 100% rename from python/samples/03-workflows/agents/azure_ai_agents_with_shared_thread.py rename to python/samples/03-workflows/agents/azure_ai_agents_with_shared_session.py diff --git a/python/samples/03-workflows/agents/azure_chat_agents_tool_calls_with_feedback.py b/python/samples/03-workflows/agents/azure_chat_agents_tool_calls_with_feedback.py index 96f25f65f7..72ecda0609 100644 --- a/python/samples/03-workflows/agents/azure_chat_agents_tool_calls_with_feedback.py +++ b/python/samples/03-workflows/agents/azure_chat_agents_tool_calls_with_feedback.py @@ -52,7 +52,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py and -# samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def fetch_product_brief( product_name: Annotated[str, Field(description="Product name to look up.")], diff --git a/python/samples/03-workflows/agents/handoff_workflow_as_agent.py b/python/samples/03-workflows/agents/handoff_workflow_as_agent.py index 0154ed4ce3..9eaa0549ec 100644 --- a/python/samples/03-workflows/agents/handoff_workflow_as_agent.py +++ b/python/samples/03-workflows/agents/handoff_workflow_as_agent.py @@ -40,7 +40,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # See: # samples/02-agents/tools/function_tool_with_approval.py -# samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def process_refund(order_number: Annotated[str, "Order number to process refund for"]) -> str: """Simulated function to process a refund for a given order number.""" diff --git a/python/samples/03-workflows/agents/workflow_as_agent_kwargs.py b/python/samples/03-workflows/agents/workflow_as_agent_kwargs.py index df61734cb8..539fdfc540 100644 --- a/python/samples/03-workflows/agents/workflow_as_agent_kwargs.py +++ b/python/samples/03-workflows/agents/workflow_as_agent_kwargs.py @@ -38,7 +38,7 @@ # Define tools that accept custom context via **kwargs # NOTE: approval_mode="never_require" is for sample brevity. # Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and -# samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_user_data( query: Annotated[str, Field(description="What user data to retrieve")], diff --git a/python/samples/03-workflows/agents/workflow_as_agent_with_thread.py b/python/samples/03-workflows/agents/workflow_as_agent_with_session.py similarity index 100% rename from python/samples/03-workflows/agents/workflow_as_agent_with_thread.py rename to python/samples/03-workflows/agents/workflow_as_agent_with_session.py diff --git a/python/samples/03-workflows/composition/sub_workflow_kwargs.py b/python/samples/03-workflows/composition/sub_workflow_kwargs.py index 71e5ae54ad..47950ea087 100644 --- a/python/samples/03-workflows/composition/sub_workflow_kwargs.py +++ b/python/samples/03-workflows/composition/sub_workflow_kwargs.py @@ -36,7 +36,7 @@ # Define tools that access custom context via **kwargs # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py and -# samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_authenticated_data( resource: Annotated[str, "The resource to fetch"], diff --git a/python/samples/03-workflows/declarative/function_tools/main.py b/python/samples/03-workflows/declarative/function_tools/main.py index 4bda15b655..521639b1ad 100644 --- a/python/samples/03-workflows/declarative/function_tools/main.py +++ b/python/samples/03-workflows/declarative/function_tools/main.py @@ -39,7 +39,7 @@ class MenuItem: ] -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_menu() -> list[dict[str, Any]]: """Get all menu items.""" diff --git a/python/samples/03-workflows/human-in-the-loop/agents_with_approval_requests.py b/python/samples/03-workflows/human-in-the-loop/agents_with_approval_requests.py index 4fcf279b9c..7d0c050457 100644 --- a/python/samples/03-workflows/human-in-the-loop/agents_with_approval_requests.py +++ b/python/samples/03-workflows/human-in-the-loop/agents_with_approval_requests.py @@ -57,7 +57,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # See: # samples/02-agents/tools/function_tool_with_approval.py -# samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_current_date() -> str: """Get the current date in YYYY-MM-DD format.""" diff --git a/python/samples/03-workflows/orchestrations/handoff_simple.py b/python/samples/03-workflows/orchestrations/handoff_simple.py index b2f40f438f..d4b2e5edf5 100644 --- a/python/samples/03-workflows/orchestrations/handoff_simple.py +++ b/python/samples/03-workflows/orchestrations/handoff_simple.py @@ -35,7 +35,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # See: # samples/getting_started/tools/function_tool_with_approval.py -# samples/getting_started/tools/function_tool_with_approval_and_threads.py. +# samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
@tool(approval_mode="never_require") def process_refund(order_number: Annotated[str, "Order number to process refund for"]) -> str: """Simulated function to process a refund for a given order number.""" diff --git a/python/samples/03-workflows/state-management/workflow_kwargs.py b/python/samples/03-workflows/state-management/workflow_kwargs.py index 8835ee365b..c7f3562fc8 100644 --- a/python/samples/03-workflows/state-management/workflow_kwargs.py +++ b/python/samples/03-workflows/state-management/workflow_kwargs.py @@ -32,7 +32,7 @@ # Define tools that accept custom context via **kwargs # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_user_data( query: Annotated[str, Field(description="What user data to retrieve")], diff --git a/python/samples/03-workflows/tool-approval/concurrent_builder_tool_approval.py b/python/samples/03-workflows/tool-approval/concurrent_builder_tool_approval.py index 227306eff0..94ca3aab88 100644 --- a/python/samples/03-workflows/tool-approval/concurrent_builder_tool_approval.py +++ b/python/samples/03-workflows/tool-approval/concurrent_builder_tool_approval.py @@ -50,7 +50,7 @@ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # See: # samples/02-agents/tools/function_tool_with_approval.py -# samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_stock_price(symbol: Annotated[str, "The stock ticker symbol"]) -> str: """Get the current stock price for a given symbol.""" diff --git a/python/samples/03-workflows/tool-approval/group_chat_builder_tool_approval.py b/python/samples/03-workflows/tool-approval/group_chat_builder_tool_approval.py index e877abdd3c..d7472d13ca 100644 --- a/python/samples/03-workflows/tool-approval/group_chat_builder_tool_approval.py +++ b/python/samples/03-workflows/tool-approval/group_chat_builder_tool_approval.py @@ -48,7 +48,7 @@ # 1. Define tools for different agents # NOTE: approval_mode="never_require" is for sample brevity. # Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def run_tests(test_suite: Annotated[str, "Name of the test suite to run"]) -> str: """Run automated tests for the application.""" diff --git a/python/samples/03-workflows/tool-approval/sequential_builder_tool_approval.py b/python/samples/03-workflows/tool-approval/sequential_builder_tool_approval.py index 10a05343a6..d9916801f4 100644 --- a/python/samples/03-workflows/tool-approval/sequential_builder_tool_approval.py +++ b/python/samples/03-workflows/tool-approval/sequential_builder_tool_approval.py @@ -58,7 +58,7 @@ def execute_database_query( # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py and -# samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_database_schema() -> str: """Get the current database schema. 
Does not require approval.""" diff --git a/python/samples/04-hosting/azure_functions/02_multi_agent/function_app.py b/python/samples/04-hosting/azure_functions/02_multi_agent/function_app.py index 2a3bd0d420..419b91b779 100644 --- a/python/samples/04-hosting/azure_functions/02_multi_agent/function_app.py +++ b/python/samples/04-hosting/azure_functions/02_multi_agent/function_app.py @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather(location: str) -> dict[str, Any]: """Get current weather for a location.""" diff --git a/python/samples/05-end-to-end/chatkit-integration/app.py b/python/samples/05-end-to-end/chatkit-integration/app.py index ab96a9bf26..a22698b085 100644 --- a/python/samples/05-end-to-end/chatkit-integration/app.py +++ b/python/samples/05-end-to-end/chatkit-integration/app.py @@ -141,7 +141,7 @@ async def stream_widget( yield ThreadItemDoneEvent(type="thread.item.done", item=widget_item) -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. 
@tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/05-end-to-end/m365-agent/m365_agent_demo/app.py b/python/samples/05-end-to-end/m365-agent/m365_agent_demo/app.py index 81e28936bc..a33a487b34 100644 --- a/python/samples/05-end-to-end/m365-agent/m365_agent_demo/app.py +++ b/python/samples/05-end-to-end/m365-agent/m365_agent_demo/app.py @@ -79,7 +79,7 @@ def load_app_config() -> AppConfig: return AppConfig(use_anonymous_mode=use_anonymous_mode, port=port, agents_sdk_config=agents_sdk_config) -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], diff --git a/python/samples/autogen-migration/single_agent/02_assistant_agent_with_tool.py b/python/samples/autogen-migration/single_agent/02_assistant_agent_with_tool.py index 134eb6ef59..0f2f254e07 100644 --- a/python/samples/autogen-migration/single_agent/02_assistant_agent_with_tool.py +++ b/python/samples/autogen-migration/single_agent/02_assistant_agent_with_tool.py @@ -62,7 +62,7 @@ async def run_agent_framework() -> None: from agent_framework.openai import OpenAIChatClient # Define tool with @tool decorator (automatic schema inference) - # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. 
+ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather(location: str) -> str: """Get the weather for a location. diff --git a/python/samples/getting_started/sessions/mem0/mem0_basic.py b/python/samples/getting_started/sessions/mem0/mem0_basic.py index f7a3a7f91f..b4e99e0a9f 100644 --- a/python/samples/getting_started/sessions/mem0/mem0_basic.py +++ b/python/samples/getting_started/sessions/mem0/mem0_basic.py @@ -9,7 +9,7 @@ from azure.identity.aio import AzureCliCredential -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def retrieve_company_report(company_code: str, detailed: bool) -> str: if company_code != "CNTS": diff --git a/python/samples/getting_started/sessions/mem0/mem0_oss.py b/python/samples/getting_started/sessions/mem0/mem0_oss.py index 2178bbfe58..1b03ac5fc1 100644 --- a/python/samples/getting_started/sessions/mem0/mem0_oss.py +++ b/python/samples/getting_started/sessions/mem0/mem0_oss.py @@ -10,7 +10,7 @@ from mem0 import AsyncMemory -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def retrieve_company_report(company_code: str, detailed: bool) -> str: if company_code != "CNTS": diff --git a/python/samples/getting_started/sessions/mem0/mem0_sessions.py b/python/samples/getting_started/sessions/mem0/mem0_sessions.py index dd657b4e1d..cc5548e979 100644 --- a/python/samples/getting_started/sessions/mem0/mem0_sessions.py +++ b/python/samples/getting_started/sessions/mem0/mem0_sessions.py @@ -9,7 +9,7 @@ from azure.identity.aio import AzureCliCredential -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_user_preferences(user_id: str) -> str: """Mock function to get user preferences.""" diff --git a/python/samples/getting_started/sessions/redis/redis_basics.py b/python/samples/getting_started/sessions/redis/redis_basics.py index 81238eb171..5f78d65320 100644 --- a/python/samples/getting_started/sessions/redis/redis_basics.py +++ b/python/samples/getting_started/sessions/redis/redis_basics.py @@ -37,7 +37,7 @@ from redisvl.utils.vectorize import OpenAITextVectorizer -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_threads.py. +# NOTE: approval_mode="never_require" is for sample brevity. 
Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def search_flights(origin_airport_code: str, destination_airport_code: str, detailed: bool = False) -> str: """Simulated flight-search tool to demonstrate tool memory. From 1f2f645091842cdef96a5c87c906a521b65c4229 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Thu, 12 Feb 2026 21:09:11 +0100 Subject: [PATCH 27/28] =?UTF-8?q?Fix=20broken=20markdown=20links=20after?= =?UTF-8?q?=20thread=E2=86=92session=20file=20renames?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/packages/core/AGENTS.md | 17 ++++++----------- python/packages/redis/README.md | 2 +- .../02-agents/context_providers/mem0/README.md | 2 +- .../02-agents/context_providers/redis/README.md | 2 +- .../02-agents/providers/azure_ai/README.md | 2 +- .../providers/azure_ai_agent/README.md | 4 ++-- .../02-agents/providers/azure_openai/README.md | 6 +++--- .../02-agents/providers/openai/README.md | 6 +++--- python/samples/03-workflows/README.md | 4 ++-- 9 files changed, 20 insertions(+), 25 deletions(-) diff --git a/python/packages/core/AGENTS.md b/python/packages/core/AGENTS.md index 3958957596..a270bc1686 100644 --- a/python/packages/core/AGENTS.md +++ b/python/packages/core/AGENTS.md @@ -12,8 +12,7 @@ agent_framework/ ├── _types.py # Core types (Message, ChatResponse, Content, etc.) ├── _tools.py # Tool definitions and function invocation ├── _middleware.py # Middleware system for request/response interception -├── _threads.py # AgentThread and message store abstractions -├── _memory.py # Context providers for memory/RAG +├── _sessions.py # AgentSession and context provider abstractions ├── _mcp.py # Model Context Protocol support ├── _workflows/ # Workflow orchestration (sequential, concurrent, handoff, etc.) 
├── openai/ # Built-in OpenAI client @@ -57,16 +56,12 @@ agent_framework/ - **`FunctionMiddleware`** - Intercepts function/tool invocations - **`AgentContext`** / **`ChatContext`** / **`FunctionInvocationContext`** - Context objects passed through middleware -### Threads (`_threads.py`) +### Sessions (`_sessions.py`) -- **`AgentThread`** - Manages conversation history for an agent -- **`ChatMessageStoreProtocol`** - Protocol for persistent message storage -- **`ChatMessageStore`** - Default in-memory implementation - -### Memory (`_memory.py`) - -- **`ContextProvider`** - Protocol for providing additional context to agents (RAG, memory systems) -- **`Context`** - Container for context data +- **`AgentSession`** - Manages conversation state and session metadata +- **`SessionContext`** - Context object for session-scoped data during agent runs +- **`BaseContextProvider`** - Base class for context providers (RAG, memory systems) +- **`BaseHistoryProvider`** - Base class for conversation history storage ### Workflows (`_workflows/`) diff --git a/python/packages/redis/README.md b/python/packages/redis/README.md index 43ab34c1ee..02ba6f7548 100644 --- a/python/packages/redis/README.md +++ b/python/packages/redis/README.md @@ -30,7 +30,7 @@ The `RedisChatMessageStore` provides persistent conversation storage using Redis #### Basic Usage Examples -See the complete [Redis history provider examples](../../samples/02-agents/conversations/redis_chat_message_store_thread.py) including: +See the complete [Redis history provider examples](../../samples/02-agents/conversations/redis_chat_message_store_session.py) including: - User session management - Conversation persistence across restarts - Session serialization and deserialization diff --git a/python/samples/02-agents/context_providers/mem0/README.md b/python/samples/02-agents/context_providers/mem0/README.md index 61d8bbd51f..4c12bb67d8 100644 --- a/python/samples/02-agents/context_providers/mem0/README.md +++ 
b/python/samples/02-agents/context_providers/mem0/README.md @@ -9,7 +9,7 @@ This folder contains examples demonstrating how to use the Mem0 context provider | File | Description | |------|-------------| | [`mem0_basic.py`](mem0_basic.py) | Basic example of using Mem0 context provider to store and retrieve user preferences across different conversation threads. | -| [`mem0_threads.py`](mem0_threads.py) | Advanced example demonstrating different thread scoping strategies with Mem0. Covers global thread scope (memories shared across all operations), per-operation thread scope (memories isolated per thread), and multiple agents with different memory configurations for personal vs. work contexts. | +| [`mem0_sessions.py`](mem0_sessions.py) | Advanced example demonstrating different session scoping strategies with Mem0. Covers global session scope (memories shared across all operations), per-operation session scope (memories isolated per session), and multiple agents with different memory configurations for personal vs. work contexts. | | [`mem0_oss.py`](mem0_oss.py) | Example of using the Mem0 Open Source self-hosted version as the context provider. Demonstrates setup and configuration for local deployment. | ## Prerequisites diff --git a/python/samples/02-agents/context_providers/redis/README.md b/python/samples/02-agents/context_providers/redis/README.md index dec2c77485..b7b25c8d77 100644 --- a/python/samples/02-agents/context_providers/redis/README.md +++ b/python/samples/02-agents/context_providers/redis/README.md @@ -11,7 +11,7 @@ This folder contains an example demonstrating how to use the Redis context provi | [`azure_redis_conversation.py`](azure_redis_conversation.py) | Demonstrates conversation persistence with RedisHistoryProvider and Azure Redis with Azure AD (Entra ID) authentication using credential provider. | | [`redis_basics.py`](redis_basics.py) | Shows standalone provider usage and agent integration.
Demonstrates writing messages to Redis, retrieving context via full‑text or hybrid vector search, and persisting preferences across threads. Also includes a simple tool example whose outputs are remembered. | [`redis_conversation.py`](redis_conversation.py) | Simple example showing conversation persistence with RedisContextProvider using traditional connection string authentication. | -| [`redis_threads.py`](redis_threads.py) | Demonstrates thread scoping. Includes: (1) global thread scope with a fixed `thread_id` shared across operations; (2) per‑operation thread scope where `scope_to_per_operation_thread_id=True` binds memory to a single thread for the provider's lifetime; and (3) multiple agents with isolated memory via different `agent_id` values. | +| [`redis_sessions.py`](redis_sessions.py) | Demonstrates session scoping. Includes: (1) global session scope with a fixed `thread_id` shared across operations; (2) per‑operation session scope where `scope_to_per_operation_thread_id=True` binds memory to a single session for the provider's lifetime; and (3) multiple agents with isolated memory via different `agent_id` values. | ## Prerequisites diff --git a/python/samples/02-agents/providers/azure_ai/README.md b/python/samples/02-agents/providers/azure_ai/README.md index a047ccb9b0..d49147989f 100644 --- a/python/samples/02-agents/providers/azure_ai/README.md +++ b/python/samples/02-agents/providers/azure_ai/README.md @@ -31,7 +31,7 @@ This folder contains examples demonstrating different ways to create and use age | [`azure_ai_with_search_context_agentic.py`](../../context_providers/azure_ai_search/azure_ai_with_search_context_agentic.py) | Shows how to use AzureAISearchContextProvider with agentic mode. Uses Knowledge Bases for multi-hop reasoning across documents with query planning. Recommended for most scenarios - slightly slower with more token consumption for query planning, but more accurate results.
| | [`azure_ai_with_search_context_semantic.py`](../../context_providers/azure_ai_search/azure_ai_with_search_context_semantic.py) | Shows how to use AzureAISearchContextProvider with semantic mode. Fast hybrid search with vector + keyword search and semantic ranking for RAG. Best for simple queries where speed is critical. | | [`azure_ai_with_sharepoint.py`](azure_ai_with_sharepoint.py) | Shows how to use SharePoint grounding with Azure AI agents to search through SharePoint content and answer user questions with proper citations. Requires a SharePoint connection configured in your Azure AI project. | -| [`azure_ai_with_thread.py`](azure_ai_with_thread.py) | Demonstrates session management with Azure AI agents, including automatic session creation for stateless conversations and explicit session management for maintaining conversation context across multiple interactions. | +| [`azure_ai_with_session.py`](azure_ai_with_session.py) | Demonstrates session management with Azure AI agents, including automatic session creation for stateless conversations and explicit session management for maintaining conversation context across multiple interactions. | | [`azure_ai_with_image_generation.py`](azure_ai_with_image_generation.py) | Shows how to use `AzureAIClient.get_image_generation_tool()` with Azure AI agents to generate images based on text prompts. | | [`azure_ai_with_memory_search.py`](azure_ai_with_memory_search.py) | Shows how to use memory search functionality with Azure AI agents for conversation persistence. Demonstrates creating memory stores and enabling agents to search through conversation history. | | [`azure_ai_with_microsoft_fabric.py`](azure_ai_with_microsoft_fabric.py) | Shows how to use Microsoft Fabric with Azure AI agents to query Fabric data sources and provide responses based on data analysis. Requires a Microsoft Fabric connection configured in your Azure AI project. 
| diff --git a/python/samples/02-agents/providers/azure_ai_agent/README.md b/python/samples/02-agents/providers/azure_ai_agent/README.md index a69572ac9d..3a52984006 100644 --- a/python/samples/02-agents/providers/azure_ai_agent/README.md +++ b/python/samples/02-agents/providers/azure_ai_agent/README.md @@ -38,7 +38,7 @@ async with ( | [`azure_ai_with_code_interpreter_file_generation.py`](azure_ai_with_code_interpreter_file_generation.py) | Shows how to retrieve file IDs from code interpreter generated files using both streaming and non-streaming approaches. | | [`azure_ai_with_code_interpreter.py`](azure_ai_with_code_interpreter.py) | Shows how to use `AzureAIAgentClient.get_code_interpreter_tool()` with Azure AI agents to write and execute Python code. Includes helper methods for accessing code interpreter data from response chunks. | | [`azure_ai_with_existing_agent.py`](azure_ai_with_existing_agent.py) | Shows how to work with an existing SDK Agent object using `provider.as_agent()`. This wraps the agent without making HTTP calls. | -| [`azure_ai_with_existing_thread.py`](azure_ai_with_existing_thread.py) | Shows how to work with a pre-existing session by providing the session ID. Demonstrates proper cleanup of manually created sessions. | +| [`azure_ai_with_existing_session.py`](azure_ai_with_existing_session.py) | Shows how to work with a pre-existing session by providing the session ID. Demonstrates proper cleanup of manually created sessions. | | [`azure_ai_with_explicit_settings.py`](azure_ai_with_explicit_settings.py) | Shows how to create an agent with explicitly configured provider settings, including project endpoint and model deployment name. | | [`azure_ai_with_azure_ai_search.py`](azure_ai_with_azure_ai_search.py) | Demonstrates how to use Azure AI Search with Azure AI agents. Shows how to create an agent with search tools using the SDK directly and wrap it with `provider.get_agent()`. 
| | [`azure_ai_with_file_search.py`](azure_ai_with_file_search.py) | Demonstrates how to use `AzureAIAgentClient.get_file_search_tool()` with Azure AI agents to search through uploaded documents. Shows file upload, vector store creation, and querying document content. | @@ -48,7 +48,7 @@ async with ( | [`azure_ai_with_multiple_tools.py`](azure_ai_with_multiple_tools.py) | Demonstrates how to use multiple tools together with Azure AI agents, including web search, MCP servers, and function tools using client static methods. Shows coordinated multi-tool interactions and approval workflows. | | [`azure_ai_with_openapi_tools.py`](azure_ai_with_openapi_tools.py) | Demonstrates how to use OpenAPI tools with Azure AI agents to integrate external REST APIs. Shows OpenAPI specification loading, anonymous authentication, session context management, and coordinated multi-API conversations. | | [`azure_ai_with_response_format.py`](azure_ai_with_response_format.py) | Demonstrates how to use structured outputs with Azure AI agents using Pydantic models. | -| [`azure_ai_with_thread.py`](azure_ai_with_thread.py) | Demonstrates session management with Azure AI agents, including automatic session creation for stateless conversations and explicit session management for maintaining conversation context across multiple interactions. | +| [`azure_ai_with_session.py`](azure_ai_with_session.py) | Demonstrates session management with Azure AI agents, including automatic session creation for stateless conversations and explicit session management for maintaining conversation context across multiple interactions. 
| ## Environment Variables diff --git a/python/samples/02-agents/providers/azure_openai/README.md b/python/samples/02-agents/providers/azure_openai/README.md index 460c2861fe..6971183ccf 100644 --- a/python/samples/02-agents/providers/azure_openai/README.md +++ b/python/samples/02-agents/providers/azure_openai/README.md @@ -11,11 +11,11 @@ This folder contains examples demonstrating different ways to create and use age | [`azure_assistants_with_existing_assistant.py`](azure_assistants_with_existing_assistant.py) | Shows how to work with a pre-existing assistant by providing the assistant ID to the Azure Assistants client. Demonstrates proper cleanup of manually created assistants. | | [`azure_assistants_with_explicit_settings.py`](azure_assistants_with_explicit_settings.py) | Shows how to initialize an agent with a specific assistants client, configuring settings explicitly including endpoint and deployment name. | | [`azure_assistants_with_function_tools.py`](azure_assistants_with_function_tools.py) | Demonstrates how to use function tools with agents. Shows both agent-level tools (defined when creating the agent) and query-level tools (provided with specific queries). | -| [`azure_assistants_with_thread.py`](azure_assistants_with_thread.py) | Demonstrates session management with Azure agents, including automatic session creation for stateless conversations and explicit session management for maintaining conversation context across multiple interactions. | +| [`azure_assistants_with_session.py`](azure_assistants_with_session.py) | Demonstrates session management with Azure agents, including automatic session creation for stateless conversations and explicit session management for maintaining conversation context across multiple interactions. | | [`azure_chat_client_basic.py`](azure_chat_client_basic.py) | The simplest way to create an agent using `Agent` with `AzureOpenAIChatClient`. 
Shows both streaming and non-streaming responses for chat-based interactions with Azure OpenAI models. | | [`azure_chat_client_with_explicit_settings.py`](azure_chat_client_with_explicit_settings.py) | Shows how to initialize an agent with a specific chat client, configuring settings explicitly including endpoint and deployment name. | | [`azure_chat_client_with_function_tools.py`](azure_chat_client_with_function_tools.py) | Demonstrates how to use function tools with agents. Shows both agent-level tools (defined when creating the agent) and query-level tools (provided with specific queries). | -| [`azure_chat_client_with_thread.py`](azure_chat_client_with_thread.py) | Demonstrates session management with Azure agents, including automatic session creation for stateless conversations and explicit session management for maintaining conversation context across multiple interactions. | +| [`azure_chat_client_with_session.py`](azure_chat_client_with_session.py) | Demonstrates session management with Azure agents, including automatic session creation for stateless conversations and explicit session management for maintaining conversation context across multiple interactions. | | [`azure_responses_client_basic.py`](azure_responses_client_basic.py) | The simplest way to create an agent using `Agent` with `AzureOpenAIResponsesClient`. Shows both streaming and non-streaming responses for structured response generation with Azure OpenAI models. | | [`azure_responses_client_code_interpreter_files.py`](azure_responses_client_code_interpreter_files.py) | Demonstrates using `AzureOpenAIResponsesClient.get_code_interpreter_tool()` with file uploads for data analysis. Shows how to create, upload, and analyze CSV files using Python code execution with Azure OpenAI Responses. | | [`azure_responses_client_image_analysis.py`](azure_responses_client_image_analysis.py) | Shows how to use Azure OpenAI Responses for image analysis and vision tasks. 
Demonstrates multi-modal messages combining text and image content using remote URLs. | @@ -26,7 +26,7 @@ This folder contains examples demonstrating different ways to create and use age | [`azure_responses_client_with_function_tools.py`](azure_responses_client_with_function_tools.py) | Demonstrates how to use function tools with agents. Shows both agent-level tools (defined when creating the agent) and query-level tools (provided with specific queries). | | [`azure_responses_client_with_hosted_mcp.py`](azure_responses_client_with_hosted_mcp.py) | Shows how to integrate Azure OpenAI Responses Client with hosted Model Context Protocol (MCP) servers using `AzureOpenAIResponsesClient.get_mcp_tool()` for extended functionality. | | [`azure_responses_client_with_local_mcp.py`](azure_responses_client_with_local_mcp.py) | Shows how to integrate Azure OpenAI Responses Client with local Model Context Protocol (MCP) servers using MCPStreamableHTTPTool for extended functionality. | -| [`azure_responses_client_with_thread.py`](azure_responses_client_with_thread.py) | Demonstrates session management with Azure agents, including automatic session creation for stateless conversations and explicit session management for maintaining conversation context across multiple interactions. | +| [`azure_responses_client_with_session.py`](azure_responses_client_with_session.py) | Demonstrates session management with Azure agents, including automatic session creation for stateless conversations and explicit session management for maintaining conversation context across multiple interactions. 
| ## Environment Variables diff --git a/python/samples/02-agents/providers/openai/README.md b/python/samples/02-agents/providers/openai/README.md index bfcdc94d90..20e757d421 100644 --- a/python/samples/02-agents/providers/openai/README.md +++ b/python/samples/02-agents/providers/openai/README.md @@ -14,12 +14,12 @@ This folder contains examples demonstrating different ways to create and use age | [`openai_assistants_with_file_search.py`](openai_assistants_with_file_search.py) | Using `OpenAIAssistantsClient.get_file_search_tool()` with `OpenAIAssistantProvider` for file search capabilities. | | [`openai_assistants_with_function_tools.py`](openai_assistants_with_function_tools.py) | Function tools with `OpenAIAssistantProvider` at both agent-level and query-level. | | [`openai_assistants_with_response_format.py`](openai_assistants_with_response_format.py) | Structured outputs with `OpenAIAssistantProvider` using Pydantic models. | -| [`openai_assistants_with_thread.py`](openai_assistants_with_thread.py) | Session management with `OpenAIAssistantProvider` for conversation context persistence. | +| [`openai_assistants_with_session.py`](openai_assistants_with_session.py) | Session management with `OpenAIAssistantProvider` for conversation context persistence. | | [`openai_chat_client_basic.py`](openai_chat_client_basic.py) | The simplest way to create an agent using `Agent` with `OpenAIChatClient`. Shows both streaming and non-streaming responses for chat-based interactions with OpenAI models. | | [`openai_chat_client_with_explicit_settings.py`](openai_chat_client_with_explicit_settings.py) | Shows how to initialize an agent with a specific chat client, configuring settings explicitly including API key and model ID. | | [`openai_chat_client_with_function_tools.py`](openai_chat_client_with_function_tools.py) | Demonstrates how to use function tools with agents. 
Shows both agent-level tools (defined when creating the agent) and query-level tools (provided with specific queries). | | [`openai_chat_client_with_local_mcp.py`](openai_chat_client_with_local_mcp.py) | Shows how to integrate OpenAI agents with local Model Context Protocol (MCP) servers for enhanced functionality and tool integration. | -| [`openai_chat_client_with_thread.py`](openai_chat_client_with_thread.py) | Demonstrates session management with OpenAI agents, including automatic session creation for stateless conversations and explicit session management for maintaining conversation context across multiple interactions. | +| [`openai_chat_client_with_session.py`](openai_chat_client_with_session.py) | Demonstrates session management with OpenAI agents, including automatic session creation for stateless conversations and explicit session management for maintaining conversation context across multiple interactions. | | [`openai_chat_client_with_web_search.py`](openai_chat_client_with_web_search.py) | Shows how to use `OpenAIChatClient.get_web_search_tool()` for web search capabilities with OpenAI agents. | | [`openai_chat_client_with_runtime_json_schema.py`](openai_chat_client_with_runtime_json_schema.py) | Shows how to supply a runtime JSON Schema via `additional_chat_options` for structured output without defining a Pydantic model. | | [`openai_responses_client_basic.py`](openai_responses_client_basic.py) | The simplest way to create an agent using `Agent` with `OpenAIResponsesClient`. Shows both streaming and non-streaming responses for structured response generation with OpenAI models. | @@ -37,7 +37,7 @@ This folder contains examples demonstrating different ways to create and use age | [`openai_responses_client_with_local_mcp.py`](openai_responses_client_with_local_mcp.py) | Shows how to integrate OpenAI agents with local Model Context Protocol (MCP) servers for enhanced functionality and tool integration. 
| | [`openai_responses_client_with_runtime_json_schema.py`](openai_responses_client_with_runtime_json_schema.py) | Shows how to supply a runtime JSON Schema via `additional_chat_options` for structured output without defining a Pydantic model. | | [`openai_responses_client_with_structured_output.py`](openai_responses_client_with_structured_output.py) | Demonstrates how to use structured outputs with OpenAI agents to get structured data responses in predefined formats. | -| [`openai_responses_client_with_thread.py`](openai_responses_client_with_thread.py) | Demonstrates session management with OpenAI agents, including automatic session creation for stateless conversations and explicit session management for maintaining conversation context across multiple interactions. | +| [`openai_responses_client_with_session.py`](openai_responses_client_with_session.py) | Demonstrates session management with OpenAI agents, including automatic session creation for stateless conversations and explicit session management for maintaining conversation context across multiple interactions. | | [`openai_responses_client_with_web_search.py`](openai_responses_client_with_web_search.py) | Shows how to use `OpenAIResponsesClient.get_web_search_tool()` for web search capabilities. | ## Environment Variables diff --git a/python/samples/03-workflows/README.md b/python/samples/03-workflows/README.md index 26eccd03e4..b72cdce54d 100644 --- a/python/samples/03-workflows/README.md +++ b/python/samples/03-workflows/README.md @@ -36,11 +36,11 @@ Once comfortable with these, explore the rest of the samples below. 
| -------------------------------------- | -------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------- | | Azure Chat Agents (Streaming) | [agents/azure_chat_agents_streaming.py](./agents/azure_chat_agents_streaming.py) | Add Azure Chat agents as edges and handle streaming events | | Azure AI Agents (Streaming) | [agents/azure_ai_agents_streaming.py](./agents/azure_ai_agents_streaming.py) | Add Azure AI agents as edges and handle streaming events | -| Azure AI Agents (Shared Thread) | [agents/azure_ai_agents_with_shared_thread.py](./agents/azure_ai_agents_with_shared_thread.py) | Share a common message session between multiple Azure AI agents in a workflow | +| Azure AI Agents (Shared Session) | [agents/azure_ai_agents_with_shared_session.py](./agents/azure_ai_agents_with_shared_session.py) | Share a common message session between multiple Azure AI agents in a workflow | | Custom Agent Executors | [agents/custom_agent_executors.py](./agents/custom_agent_executors.py) | Create executors to handle agent run methods | | Workflow as Agent (Reflection Pattern) | [agents/workflow_as_agent_reflection_pattern.py](./agents/workflow_as_agent_reflection_pattern.py) | Wrap a workflow so it can behave like an agent (reflection pattern) | | Workflow as Agent + HITL | [agents/workflow_as_agent_human_in_the_loop.py](./agents/workflow_as_agent_human_in_the_loop.py) | Extend workflow-as-agent with human-in-the-loop capability | -| Workflow as Agent with Session | [agents/workflow_as_agent_with_thread.py](./agents/workflow_as_agent_with_thread.py) | Use AgentSession to maintain conversation history across workflow-as-agent invocations | +| Workflow as Agent with Session | [agents/workflow_as_agent_with_session.py](./agents/workflow_as_agent_with_session.py) | Use AgentSession to maintain conversation history across workflow-as-agent
invocations | | Workflow as Agent kwargs | [agents/workflow_as_agent_kwargs.py](./agents/workflow_as_agent_kwargs.py) | Pass custom context (data, user tokens) via kwargs through workflow.as_agent() to @ai_function tools | ### checkpoint From 3d7aa5ffc1521401d69223073c5c1ac0ee9a54d6 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Thu, 12 Feb 2026 21:55:22 +0100 Subject: [PATCH 28/28] fix azure ai test --- .../azure-ai/tests/test_azure_ai_client.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/python/packages/azure-ai/tests/test_azure_ai_client.py b/python/packages/azure-ai/tests/test_azure_ai_client.py index abcb2a5bda..1114747d1b 100644 --- a/python/packages/azure-ai/tests/test_azure_ai_client.py +++ b/python/packages/azure-ai/tests/test_azure_ai_client.py @@ -1473,39 +1473,39 @@ async def test_integration_agent_hosted_code_interpreter_tool(): @pytest.mark.flaky @skip_if_azure_ai_integration_tests_disabled -async def test_integration_agent_existing_thread(): - """Test Azure Responses Client agent with existing thread to continue conversations across agent instances.""" - # First conversation - capture the thread - preserved_thread = None +async def test_integration_agent_existing_session(): + """Test Azure Responses Client agent with existing session to continue conversations across agent instances.""" + # First conversation - capture the session + preserved_session = None async with ( - temporary_chat_client(agent_name="af-int-test-existing-thread") as client, + temporary_chat_client(agent_name="af-int-test-existing-session") as client, Agent( client=client, instructions="You are a helpful assistant with good memory.", ) as first_agent, ): - # Start a conversation and capture the thread - thread = first_agent.get_new_thread() - first_response = await first_agent.run("My hobby is photography. 
Remember this.", thread=thread, store=True) + # Start a conversation and capture the session + session = first_agent.create_session() + first_response = await first_agent.run("My hobby is photography. Remember this.", session=session, store=True) assert isinstance(first_response, AgentResponse) assert first_response.text is not None - # Preserve the thread for reuse - preserved_thread = thread + # Preserve the session for reuse + preserved_session = session - # Second conversation - reuse the thread in a new agent instance - if preserved_thread: + # Second conversation - reuse the session in a new agent instance + if preserved_session: async with ( - temporary_chat_client(agent_name="af-int-test-existing-thread-2") as client, + temporary_chat_client(agent_name="af-int-test-existing-session-2") as client, Agent( client=client, instructions="You are a helpful assistant with good memory.", ) as second_agent, ): - # Reuse the preserved thread - second_response = await second_agent.run("What is my hobby?", thread=preserved_thread) + # Reuse the preserved session + second_response = await second_agent.run("What is my hobby?", session=preserved_session) assert isinstance(second_response, AgentResponse) assert second_response.text is not None