-
Notifications
You must be signed in to change notification settings - Fork 769
feat: naive token estimation via tiktoken #2031
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| """Abstract base class for Agent model providers.""" | ||
|
|
||
| import abc | ||
| import json | ||
| import logging | ||
| from collections.abc import AsyncGenerator, AsyncIterable | ||
| from dataclasses import dataclass | ||
|
|
@@ -10,7 +11,7 @@ | |
|
|
||
| from ..hooks.events import AfterInvocationEvent | ||
| from ..plugins.plugin import Plugin | ||
| from ..types.content import Messages, SystemContentBlock | ||
| from ..types.content import ContentBlock, Messages, SystemContentBlock | ||
| from ..types.streaming import StreamEvent | ||
| from ..types.tools import ToolChoice, ToolSpec | ||
|
|
||
|
|
@@ -21,6 +22,110 @@ | |
|
|
||
| T = TypeVar("T", bound=BaseModel) | ||
|
|
||
_DEFAULT_ENCODING = "cl100k_base"
_cached_encoding: Any = None


def _get_encoding() -> Any:
    """Return the module's shared tiktoken encoding, creating it on first use.

    The encoding object is memoized in ``_cached_encoding`` so the tiktoken
    lookup (and the lazy import) happens at most once per process.

    Returns:
        The ``cl100k_base`` tiktoken encoding.

    Raises:
        ImportError: If the optional ``tiktoken`` dependency is not installed.
    """
    global _cached_encoding
    if _cached_encoding is not None:
        return _cached_encoding
    # Import lazily: tiktoken is an optional extra, so the module must stay
    # importable without it.
    try:
        import tiktoken
    except ImportError as err:
        raise ImportError(
            "tiktoken is required for token estimation. "
            "Install it with: pip install strands-agents[token-estimation]"
        ) from err
    _cached_encoding = tiktoken.get_encoding(_DEFAULT_ENCODING)
    return _cached_encoding
|
|
||
|
|
||
| def _count_content_block_tokens(block: ContentBlock, encoding: Any) -> int: | ||
| """Count tokens for a single content block.""" | ||
| total = 0 | ||
|
|
||
| if "text" in block: | ||
| total += len(encoding.encode(block["text"])) | ||
|
|
||
| if "toolUse" in block: | ||
| tool_use = block["toolUse"] | ||
| total += len(encoding.encode(tool_use.get("name", ""))) | ||
| try: | ||
| total += len(encoding.encode(json.dumps(tool_use.get("input", {})))) | ||
| except (TypeError, ValueError): | ||
| logger.debug( | ||
| "tool_name=<%s> | skipping non-serializable toolUse input for token estimation", | ||
| tool_use.get("name", "unknown"), | ||
| ) | ||
|
|
||
| if "toolResult" in block: | ||
| tool_result = block["toolResult"] | ||
| for item in tool_result.get("content", []): | ||
| if "text" in item: | ||
| total += len(encoding.encode(item["text"])) | ||
|
|
||
| if "reasoningContent" in block: | ||
| reasoning = block["reasoningContent"] | ||
| if "reasoningText" in reasoning: | ||
| reasoning_text = reasoning["reasoningText"] | ||
| if "text" in reasoning_text: | ||
| total += len(encoding.encode(reasoning_text["text"])) | ||
|
|
||
| if "guardContent" in block: | ||
| guard = block["guardContent"] | ||
| if "text" in guard and "text" in guard["text"]: | ||
| total += len(encoding.encode(guard["text"]["text"])) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If |
||
|
|
||
| if "citationsContent" in block: | ||
| citations = block["citationsContent"] | ||
| if "content" in citations: | ||
| for citation_item in citations["content"]: | ||
| if "text" in citation_item: | ||
| total += len(encoding.encode(citation_item["text"])) | ||
|
|
||
| return total | ||
|
|
||
|
|
||
def _estimate_tokens_with_tiktoken(
    messages: Messages,
    tool_specs: list[ToolSpec] | None = None,
    system_prompt: str | None = None,
    system_prompt_content: list[SystemContentBlock] | None = None,
) -> int:
    """Estimate input tokens by serializing everything to text and counting with tiktoken.

    Best-effort fallback for providers without a native token-counting API.
    Accuracy varies by model but is sufficient for threshold-based decisions
    such as deciding when to compress conversation history.
    """
    encoding = _get_encoding()

    # Prefer the structured system_prompt_content over the plain system_prompt
    # to avoid double-counting: providers wrap system_prompt into
    # system_prompt_content when both are supplied.
    if system_prompt_content:
        total = sum(len(encoding.encode(sys_block["text"])) for sys_block in system_prompt_content if "text" in sys_block)
    elif system_prompt:
        total = len(encoding.encode(system_prompt))
    else:
        total = 0

    total += sum(
        _count_content_block_tokens(content_block, encoding)
        for message in messages
        for content_block in message["content"]
    )

    for spec in tool_specs or []:
        try:
            total += len(encoding.encode(json.dumps(spec)))
        except (TypeError, ValueError):
            logger.debug(
                "tool_name=<%s> | skipping non-serializable tool spec for token estimation",
                spec.get("name", "unknown"),
            )

    return total
|
|
||
|
|
||
| @dataclass | ||
| class CacheConfig: | ||
|
|
@@ -130,6 +235,34 @@ def stream( | |
| """ | ||
| pass | ||
|
|
||
| def _estimate_tokens( | ||
| self, | ||
| messages: Messages, | ||
| tool_specs: list[ToolSpec] | None = None, | ||
| system_prompt: str | None = None, | ||
| system_prompt_content: list[SystemContentBlock] | None = None, | ||
| ) -> int: | ||
| """Estimate token count for the given input before sending to the model. | ||
|
|
||
| Used for proactive context management (e.g., triggering compression at a | ||
| threshold). This is a naive approximation using tiktoken's cl100k_base encoding. | ||
| Accuracy varies by model provider but is estimated to be within 5-15% for most providers. | ||
| Not intended for billing or precise quota calculations. | ||
|
|
||
| Subclasses may override this method to provide model-specific token counting | ||
| using native APIs for improved accuracy. | ||
|
|
||
| Args: | ||
| messages: List of message objects to estimate tokens for. | ||
| tool_specs: List of tool specifications to include in the estimate. | ||
| system_prompt: Plain string system prompt. Ignored if system_prompt_content is provided. | ||
| system_prompt_content: Structured system prompt content blocks. Takes priority over system_prompt. | ||
|
|
||
| Returns: | ||
| Estimated total input tokens. | ||
| """ | ||
| return _estimate_tokens_with_tiktoken(messages, tool_specs, system_prompt, system_prompt_content) | ||
|
|
||
|
|
||
| class _ModelPlugin(Plugin): | ||
| """Plugin that manages model-related lifecycle hooks.""" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I get caching, but why do we keep importing inside the method? Is this intentionally lazy loading?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe token estimation should be its own file?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is intentional, since
tiktoken is an optional dependency