Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/packages/core/agent_framework/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
Content,
Message,
ResponseStream,
UsageDetails,
)

else:
Expand Down Expand Up @@ -2095,6 +2096,7 @@ def get_response(
ChatResponse,
ChatResponseUpdate,
ResponseStream,
add_usage_details,
)

super_get_response = super().get_response # type: ignore[misc]
Expand Down Expand Up @@ -2160,6 +2162,7 @@ async def _get_response() -> ChatResponse[Any]:
prepped_messages = list(messages)
fcc_messages: list[Message] = []
response: ChatResponse[Any] | None = None
aggregated_usage: UsageDetails | None = None

loop_enabled = self.function_invocation_configuration.get("enabled", True)
max_iterations = self.function_invocation_configuration.get("max_iterations", DEFAULT_MAX_ITERATIONS)
Expand Down Expand Up @@ -2191,6 +2194,7 @@ async def _get_response() -> ChatResponse[Any]:
client_kwargs=filtered_kwargs,
),
)
aggregated_usage = add_usage_details(aggregated_usage, response.usage_details)

if response.conversation_id is not None:
_update_conversation_id(kwargs, response.conversation_id, mutable_options)
Expand All @@ -2207,6 +2211,7 @@ async def _get_response() -> ChatResponse[Any]:
execute_function_calls=execute_function_calls,
)
if result.get("action") == "return":
response.usage_details = aggregated_usage
return response
total_function_calls += result.get("function_call_count", 0)
if result.get("action") == "stop":
Expand Down Expand Up @@ -2262,6 +2267,8 @@ async def _get_response() -> ChatResponse[Any]:
client_kwargs=filtered_kwargs,
),
)
aggregated_usage = add_usage_details(aggregated_usage, response.usage_details)
response.usage_details = aggregated_usage
if fcc_messages:
for msg in reversed(fcc_messages):
response.messages.insert(0, msg)
Expand Down
141 changes: 141 additions & 0 deletions python/packages/core/tests/core/test_observability.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ChatResponseUpdate,
Content,
Message,
RawAgent,
ResponseStream,
SupportsAgentRun,
UsageDetails,
Expand Down Expand Up @@ -2781,3 +2782,143 @@ def mock_get_meter(*args, **kwargs):
meter = get_meter(name="test", attributes={"key": "val"})
assert meter is not None
assert call_count == 2


# region Agent token usage aggregation


@tool(name="get_weather", description="Get weather for a city", approval_mode="never_require")
def _get_weather(city: str) -> str:
"""Get weather for a city."""
return "Sunny, 72°F"


@pytest.mark.parametrize("enable_sensitive_data", [False], indirect=True)
async def test_agent_invoke_span_aggregates_usage_across_tool_calls(span_exporter: InMemorySpanExporter):
"""The invoke_agent span should sum token usage from all chat completions in the function invocation loop."""
from tests.core.conftest import MockBaseChatClient

class _InstrumentedAgent(AgentTelemetryLayer, RawAgent):
pass

client = MockBaseChatClient()
client.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="call_1", name="get_weather", arguments='{"city": "Seattle"}')
],
),
usage_details=UsageDetails(input_token_count=2239, output_token_count=192),
),
ChatResponse(
messages=Message(role="assistant", text="The weather in Seattle is sunny."),
usage_details=UsageDetails(input_token_count=2569, output_token_count=99),
),
]

agent = _InstrumentedAgent(client=client, name="test_agent", id="test_agent_id")

span_exporter.clear()
await agent.run(
messages="What is the weather in Seattle?",
options={"tools": [_get_weather], "tool_choice": "auto"},
)

spans = span_exporter.get_finished_spans()

invoke_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.AGENT_INVOKE_OPERATION]
assert len(invoke_spans) == 1
agent_span = invoke_spans[0]

chat_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.CHAT_COMPLETION_OPERATION]
assert len(chat_spans) == 2

# Individual chat spans retain their own usage
assert chat_spans[0].attributes.get(OtelAttr.INPUT_TOKENS) == 2239
assert chat_spans[0].attributes.get(OtelAttr.OUTPUT_TOKENS) == 192
assert chat_spans[1].attributes.get(OtelAttr.INPUT_TOKENS) == 2569
assert chat_spans[1].attributes.get(OtelAttr.OUTPUT_TOKENS) == 99

# The invoke_agent span must report the aggregate across all LLM round-trips
assert agent_span.attributes.get(OtelAttr.INPUT_TOKENS) == 2239 + 2569
assert agent_span.attributes.get(OtelAttr.OUTPUT_TOKENS) == 192 + 99


@pytest.mark.parametrize("enable_sensitive_data", [False], indirect=True)
async def test_agent_invoke_span_usage_single_call(span_exporter: InMemorySpanExporter):
"""When only one chat completion occurs, the invoke_agent span usage equals that single call."""
from tests.core.conftest import MockBaseChatClient

class _InstrumentedAgent(AgentTelemetryLayer, RawAgent):
pass

client = MockBaseChatClient()
client.run_responses = [
ChatResponse(
messages=Message(role="assistant", text="Hello!"),
usage_details=UsageDetails(input_token_count=100, output_token_count=50),
),
]

agent = _InstrumentedAgent(client=client, name="test_agent", id="test_agent_id")

span_exporter.clear()
await agent.run(messages="Hi")

spans = span_exporter.get_finished_spans()
invoke_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.AGENT_INVOKE_OPERATION]
assert len(invoke_spans) == 1

assert invoke_spans[0].attributes.get(OtelAttr.INPUT_TOKENS) == 100
assert invoke_spans[0].attributes.get(OtelAttr.OUTPUT_TOKENS) == 50


@pytest.mark.parametrize("enable_sensitive_data", [False], indirect=True)
async def test_agent_invoke_span_aggregates_usage_on_max_iterations_exhaustion(span_exporter: InMemorySpanExporter):
"""When the function invocation loop exhausts max_iterations, the final response aggregates usage
from all rounds."""
from tests.core.conftest import MockBaseChatClient

class _InstrumentedAgent(AgentTelemetryLayer, RawAgent):
pass

client = MockBaseChatClient(
function_invocation_configuration={"max_iterations": 1},
)
client.run_responses = [
# Iteration 0: model returns a tool call
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="call_1", name="get_weather", arguments='{"city": "Seattle"}')
],
),
usage_details=UsageDetails(input_token_count=500, output_token_count=100),
),
# Exhaustion path: consumed by tool_choice="none" final call (mock ignores usage)
ChatResponse(
messages=Message(role="assistant", text="placeholder"),
usage_details=UsageDetails(input_token_count=300, output_token_count=60),
),
]

agent = _InstrumentedAgent(client=client, name="test_agent", id="test_agent_id")

span_exporter.clear()
await agent.run(
messages="What is the weather in Seattle?",
options={"tools": [_get_weather], "tool_choice": "auto"},
)

spans = span_exporter.get_finished_spans()

invoke_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.AGENT_INVOKE_OPERATION]
assert len(invoke_spans) == 1
agent_span = invoke_spans[0]

# The invoke_agent span must aggregate usage from the in-loop call and the final exhaustion call
assert agent_span.attributes.get(OtelAttr.INPUT_TOKENS) == 500
assert agent_span.attributes.get(OtelAttr.OUTPUT_TOKENS) == 100
Loading