diff --git a/sentry_sdk/ai/span_config.py b/sentry_sdk/ai/span_config.py new file mode 100644 index 0000000000..0520c87de4 --- /dev/null +++ b/sentry_sdk/ai/span_config.py @@ -0,0 +1,88 @@ +import sentry_sdk +from sentry_sdk.consts import SPANDATA +from sentry_sdk.ai.monitoring import record_token_usage +from sentry_sdk.ai.utils import ( + get_first_from_sources, + set_data_normalized, + set_span_data_from_sources, + normalize_message_roles, + truncate_and_annotate_messages, +) +from sentry_sdk.scope import should_send_default_pii + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any, Dict, List, Optional + + from sentry_sdk.tracing import Span + + +def set_request_span_data(span, kwargs, integration, config, span_data=None): + # type: (Span, Dict[str, Any], Any, Dict[str, Any], Dict[str, Any] | None) -> None + """Set request/static span data from a declarative config.""" + for key, value in config.get("static", {}).items(): + set_data_normalized(span, key, value) + if span_data: + for key, value in span_data.items(): + set_data_normalized(span, key, value) + + for kwarg_key, span_attr in config.get("params", {}).items(): + if kwarg_key in kwargs: + value = kwargs[kwarg_key] + set_data_normalized(span, span_attr, value) + + if should_send_default_pii() and integration.include_prompts: + for kwarg_key, span_attr in config.get("pii_params", {}).items(): + if kwarg_key in kwargs: + value = kwargs[kwarg_key] + set_data_normalized(span, span_attr, value) + + +def set_request_messages(span, messages, target=None): + # type: (Span, Any, Optional[str]) -> None + """Normalize, truncate, and set request messages on the span. + + Caller is responsible for PII gating. 
+ """ + if not messages: + return + messages = normalize_message_roles(messages) + scope = sentry_sdk.get_current_scope() + messages = truncate_and_annotate_messages(messages, span, scope) + if messages is not None: + set_data_normalized( + span, target or SPANDATA.GEN_AI_REQUEST_MESSAGES, messages, unpack=False + ) + + +def set_response_span_data( + span, response, include_pii, response_config, response_text=None +): + # type: (Span, Any, bool, Dict[str, Any], Optional[List[str]]) -> None + """Set response span data from a declarative config.""" + set_span_data_from_sources( + span, response, response_config.get("sources", {}), require_truthy=False + ) + + if include_pii: + pii_sources = response_config.get("pii_sources") + if pii_sources: + set_span_data_from_sources(span, response, pii_sources, require_truthy=True) + if response_text: + set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, response_text) + + usage_config = response_config.get("usage") + if usage_config: + record_token_usage( + span, + input_tokens=get_first_from_sources( + response, usage_config.get("input_tokens", []) + ), + output_tokens=get_first_from_sources( + response, usage_config.get("output_tokens", []) + ), + total_tokens=get_first_from_sources( + response, usage_config.get("total_tokens", []) + ), + ) diff --git a/sentry_sdk/ai/utils.py b/sentry_sdk/ai/utils.py index 5acc501172..a05a068097 100644 --- a/sentry_sdk/ai/utils.py +++ b/sentry_sdk/ai/utils.py @@ -8,7 +8,7 @@ from sentry_sdk._types import BLOB_DATA_SUBSTITUTE if TYPE_CHECKING: - from typing import Any, Callable, Dict, List, Optional, Tuple + from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple from sentry_sdk.tracing import Span @@ -30,7 +30,7 @@ class GEN_AI_ALLOWED_MESSAGE_ROLES: GEN_AI_MESSAGE_ROLE_REVERSE_MAPPING = { GEN_AI_ALLOWED_MESSAGE_ROLES.SYSTEM: ["system"], GEN_AI_ALLOWED_MESSAGE_ROLES.USER: ["user", "human"], - GEN_AI_ALLOWED_MESSAGE_ROLES.ASSISTANT: ["assistant", "ai"], + 
GEN_AI_ALLOWED_MESSAGE_ROLES.ASSISTANT: ["assistant", "ai", "chatbot"], GEN_AI_ALLOWED_MESSAGE_ROLES.TOOL: ["tool", "tool_call"], } @@ -725,3 +725,32 @@ def set_conversation_id(conversation_id: str) -> None: """ scope = sentry_sdk.get_current_scope() scope.set_conversation_id(conversation_id) + + +def transitive_getattr(obj, *attrs): + # type: (Any, str) -> Any + current = obj + for attr in attrs: + current = getattr(current, attr, None) + if current is None: + return None + return current + + +def get_first_from_sources(obj, source_paths, require_truthy=False): + # type: (Any, Sequence[tuple[str, ...]], bool) -> Any + for source_path in source_paths: + value = transitive_getattr(obj, *source_path) + if not value: + continue + if not require_truthy or value: + return value + return None + + +def set_span_data_from_sources(span, obj, target_sources, require_truthy): + # type: (Any, Any, Mapping[str, Sequence[tuple[str, ...]]], bool) -> None + for spandata_key, source_paths in target_sources.items(): + value = get_first_from_sources(obj, source_paths, require_truthy=require_truthy) + if value is not None: + set_data_normalized(span, spandata_key, value) diff --git a/sentry_sdk/integrations/cohere.py b/sentry_sdk/integrations/cohere.py deleted file mode 100644 index f45a02f2b5..0000000000 --- a/sentry_sdk/integrations/cohere.py +++ /dev/null @@ -1,269 +0,0 @@ -import sys -from functools import wraps - -from sentry_sdk import consts -from sentry_sdk.ai.monitoring import record_token_usage -from sentry_sdk.consts import SPANDATA -from sentry_sdk.ai.utils import set_data_normalized - -from typing import TYPE_CHECKING - -from sentry_sdk.tracing_utils import set_span_errored - -if TYPE_CHECKING: - from typing import Any, Callable, Iterator - from sentry_sdk.tracing import Span - -import sentry_sdk -from sentry_sdk.scope import should_send_default_pii -from sentry_sdk.integrations import DidNotEnable, Integration -from sentry_sdk.utils import capture_internal_exceptions, 
event_from_exception, reraise - -try: - from cohere.client import Client - from cohere.base_client import BaseCohere - from cohere import ( - ChatStreamEndEvent, - NonStreamedChatResponse, - ) - - if TYPE_CHECKING: - from cohere import StreamedChatResponse -except ImportError: - raise DidNotEnable("Cohere not installed") - -try: - # cohere 5.9.3+ - from cohere import StreamEndStreamedChatResponse -except ImportError: - from cohere import StreamedChatResponse_StreamEnd as StreamEndStreamedChatResponse - - -COLLECTED_CHAT_PARAMS = { - "model": SPANDATA.AI_MODEL_ID, - "k": SPANDATA.AI_TOP_K, - "p": SPANDATA.AI_TOP_P, - "seed": SPANDATA.AI_SEED, - "frequency_penalty": SPANDATA.AI_FREQUENCY_PENALTY, - "presence_penalty": SPANDATA.AI_PRESENCE_PENALTY, - "raw_prompting": SPANDATA.AI_RAW_PROMPTING, -} - -COLLECTED_PII_CHAT_PARAMS = { - "tools": SPANDATA.AI_TOOLS, - "preamble": SPANDATA.AI_PREAMBLE, -} - -COLLECTED_CHAT_RESP_ATTRS = { - "generation_id": SPANDATA.AI_GENERATION_ID, - "is_search_required": SPANDATA.AI_SEARCH_REQUIRED, - "finish_reason": SPANDATA.AI_FINISH_REASON, -} - -COLLECTED_PII_CHAT_RESP_ATTRS = { - "citations": SPANDATA.AI_CITATIONS, - "documents": SPANDATA.AI_DOCUMENTS, - "search_queries": SPANDATA.AI_SEARCH_QUERIES, - "search_results": SPANDATA.AI_SEARCH_RESULTS, - "tool_calls": SPANDATA.AI_TOOL_CALLS, -} - - -class CohereIntegration(Integration): - identifier = "cohere" - origin = f"auto.ai.{identifier}" - - def __init__(self: "CohereIntegration", include_prompts: bool = True) -> None: - self.include_prompts = include_prompts - - @staticmethod - def setup_once() -> None: - BaseCohere.chat = _wrap_chat(BaseCohere.chat, streaming=False) - Client.embed = _wrap_embed(Client.embed) - BaseCohere.chat_stream = _wrap_chat(BaseCohere.chat_stream, streaming=True) - - -def _capture_exception(exc: "Any") -> None: - set_span_errored() - - event, hint = event_from_exception( - exc, - client_options=sentry_sdk.get_client().options, - mechanism={"type": "cohere", 
"handled": False}, - ) - sentry_sdk.capture_event(event, hint=hint) - - -def _wrap_chat(f: "Callable[..., Any]", streaming: bool) -> "Callable[..., Any]": - def collect_chat_response_fields( - span: "Span", res: "NonStreamedChatResponse", include_pii: bool - ) -> None: - if include_pii: - if hasattr(res, "text"): - set_data_normalized( - span, - SPANDATA.AI_RESPONSES, - [res.text], - ) - for pii_attr in COLLECTED_PII_CHAT_RESP_ATTRS: - if hasattr(res, pii_attr): - set_data_normalized(span, "ai." + pii_attr, getattr(res, pii_attr)) - - for attr in COLLECTED_CHAT_RESP_ATTRS: - if hasattr(res, attr): - set_data_normalized(span, "ai." + attr, getattr(res, attr)) - - if hasattr(res, "meta"): - if hasattr(res.meta, "billed_units"): - record_token_usage( - span, - input_tokens=res.meta.billed_units.input_tokens, - output_tokens=res.meta.billed_units.output_tokens, - ) - elif hasattr(res.meta, "tokens"): - record_token_usage( - span, - input_tokens=res.meta.tokens.input_tokens, - output_tokens=res.meta.tokens.output_tokens, - ) - - if hasattr(res.meta, "warnings"): - set_data_normalized(span, SPANDATA.AI_WARNINGS, res.meta.warnings) - - @wraps(f) - def new_chat(*args: "Any", **kwargs: "Any") -> "Any": - integration = sentry_sdk.get_client().get_integration(CohereIntegration) - - if ( - integration is None - or "message" not in kwargs - or not isinstance(kwargs.get("message"), str) - ): - return f(*args, **kwargs) - - message = kwargs.get("message") - - span = sentry_sdk.start_span( - op=consts.OP.COHERE_CHAT_COMPLETIONS_CREATE, - name="cohere.client.Chat", - origin=CohereIntegration.origin, - ) - span.__enter__() - try: - res = f(*args, **kwargs) - except Exception as e: - exc_info = sys.exc_info() - with capture_internal_exceptions(): - _capture_exception(e) - span.__exit__(None, None, None) - reraise(*exc_info) - - with capture_internal_exceptions(): - if should_send_default_pii() and integration.include_prompts: - set_data_normalized( - span, - 
SPANDATA.AI_INPUT_MESSAGES, - list( - map( - lambda x: { - "role": getattr(x, "role", "").lower(), - "content": getattr(x, "message", ""), - }, - kwargs.get("chat_history", []), - ) - ) - + [{"role": "user", "content": message}], - ) - for k, v in COLLECTED_PII_CHAT_PARAMS.items(): - if k in kwargs: - set_data_normalized(span, v, kwargs[k]) - - for k, v in COLLECTED_CHAT_PARAMS.items(): - if k in kwargs: - set_data_normalized(span, v, kwargs[k]) - set_data_normalized(span, SPANDATA.AI_STREAMING, False) - - if streaming: - old_iterator = res - - def new_iterator() -> "Iterator[StreamedChatResponse]": - with capture_internal_exceptions(): - for x in old_iterator: - if isinstance(x, ChatStreamEndEvent) or isinstance( - x, StreamEndStreamedChatResponse - ): - collect_chat_response_fields( - span, - x.response, - include_pii=should_send_default_pii() - and integration.include_prompts, - ) - yield x - - span.__exit__(None, None, None) - - return new_iterator() - elif isinstance(res, NonStreamedChatResponse): - collect_chat_response_fields( - span, - res, - include_pii=should_send_default_pii() - and integration.include_prompts, - ) - span.__exit__(None, None, None) - else: - set_data_normalized(span, "unknown_response", True) - span.__exit__(None, None, None) - return res - - return new_chat - - -def _wrap_embed(f: "Callable[..., Any]") -> "Callable[..., Any]": - @wraps(f) - def new_embed(*args: "Any", **kwargs: "Any") -> "Any": - integration = sentry_sdk.get_client().get_integration(CohereIntegration) - if integration is None: - return f(*args, **kwargs) - - with sentry_sdk.start_span( - op=consts.OP.COHERE_EMBEDDINGS_CREATE, - name="Cohere Embedding Creation", - origin=CohereIntegration.origin, - ) as span: - if "texts" in kwargs and ( - should_send_default_pii() and integration.include_prompts - ): - if isinstance(kwargs["texts"], str): - set_data_normalized(span, SPANDATA.AI_TEXTS, [kwargs["texts"]]) - elif ( - isinstance(kwargs["texts"], list) - and 
len(kwargs["texts"]) > 0 - and isinstance(kwargs["texts"][0], str) - ): - set_data_normalized( - span, SPANDATA.AI_INPUT_MESSAGES, kwargs["texts"] - ) - - if "model" in kwargs: - set_data_normalized(span, SPANDATA.AI_MODEL_ID, kwargs["model"]) - try: - res = f(*args, **kwargs) - except Exception as e: - exc_info = sys.exc_info() - with capture_internal_exceptions(): - _capture_exception(e) - reraise(*exc_info) - if ( - hasattr(res, "meta") - and hasattr(res.meta, "billed_units") - and hasattr(res.meta.billed_units, "input_tokens") - ): - record_token_usage( - span, - input_tokens=res.meta.billed_units.input_tokens, - total_tokens=res.meta.billed_units.input_tokens, - ) - return res - - return new_embed diff --git a/sentry_sdk/integrations/cohere/__init__.py b/sentry_sdk/integrations/cohere/__init__.py new file mode 100644 index 0000000000..01d1f2e53c --- /dev/null +++ b/sentry_sdk/integrations/cohere/__init__.py @@ -0,0 +1,109 @@ +import sys +from functools import wraps + +from sentry_sdk.consts import OP, SPANDATA +from sentry_sdk.ai.span_config import ( + set_request_span_data, + set_request_messages, + set_response_span_data, +) +from sentry_sdk.integrations.cohere.configs import COHERE_EMBED_CONFIG + +from typing import TYPE_CHECKING + +from sentry_sdk.scope import should_send_default_pii +from sentry_sdk.tracing_utils import set_span_errored + +if TYPE_CHECKING: + from typing import Any, Callable + +import sentry_sdk +from sentry_sdk.integrations import DidNotEnable, Integration +from sentry_sdk.utils import capture_internal_exceptions, event_from_exception, reraise + +try: + from cohere import __version__ as cohere_version # noqa: F401 +except ImportError: + raise DidNotEnable("Cohere not installed") + + +class CohereIntegration(Integration): + identifier = "cohere" + origin = f"auto.ai.{identifier}" + + def __init__(self, include_prompts=True): + # type: (bool) -> None + self.include_prompts = include_prompts + + @staticmethod + def setup_once(): + # type: 
() -> None + # Lazy imports to avoid circular dependencies: + from sentry_sdk.integrations.cohere.v1 import setup_v1 + from sentry_sdk.integrations.cohere.v2 import setup_v2 + + setup_v1(_wrap_embed) + setup_v2(_wrap_embed) + + +def _capture_exception(exc): + # type: (Any) -> None + set_span_errored() + + event, hint = event_from_exception( + exc, + client_options=sentry_sdk.get_client().options, + mechanism={"type": "cohere", "handled": False}, + ) + sentry_sdk.capture_event(event, hint=hint) + + +def _wrap_embed(f): + # type: (Callable[..., Any]) -> Callable[..., Any] + @wraps(f) + def new_embed(*args, **kwargs): + # type: (*Any, **Any) -> Any + integration = sentry_sdk.get_client().get_integration(CohereIntegration) + if integration is None: + return f(*args, **kwargs) + + model = kwargs.get("model", "") + + include_pii = should_send_default_pii() and integration.include_prompts + + with sentry_sdk.start_span( + op=OP.GEN_AI_EMBEDDINGS, + name=f"embeddings {model}".strip(), + origin=CohereIntegration.origin, + ) as span: + set_request_span_data(span, kwargs, integration, COHERE_EMBED_CONFIG) + if include_pii and "texts" in kwargs: + set_request_messages( + span, + _normalize_embedding_input(kwargs["texts"]), + target=SPANDATA.GEN_AI_EMBEDDINGS_INPUT, + ) + + try: + response = f(*args, **kwargs) + except Exception as e: + exc_info = sys.exc_info() + with capture_internal_exceptions(): + _capture_exception(e) + reraise(*exc_info) + + set_response_span_data( + span, response, False, COHERE_EMBED_CONFIG["response"] + ) + return response + + return new_embed + + +def _normalize_embedding_input(texts): + # type: (Any) -> Any + if isinstance(texts, list): + return texts + if isinstance(texts, tuple): + return list(texts) + return [texts] diff --git a/sentry_sdk/integrations/cohere/configs.py b/sentry_sdk/integrations/cohere/configs.py new file mode 100644 index 0000000000..67db83b7d3 --- /dev/null +++ b/sentry_sdk/integrations/cohere/configs.py @@ -0,0 +1,146 @@ +from 
sentry_sdk.consts import SPANDATA + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any, Dict, Sequence, Tuple + from typing_extensions import TypedDict + + # Source paths: list of attribute chains to try in order. + # e.g. [("meta", "billed_units", "input_tokens"), ("meta", "tokens", "input_tokens")] + SourcePaths = Sequence[Tuple[str, ...]] + + # Maps a SPANDATA key to source paths on the response object. + # e.g. {SPANDATA.GEN_AI_RESPONSE_ID: [("id",)]} + SourceMapping = Dict[str, SourcePaths] + + class UsageConfig(TypedDict, total=False): + """Declarative token usage extraction paths (from response object).""" + + input_tokens: SourcePaths + output_tokens: SourcePaths + total_tokens: SourcePaths + + class ResponseConfig(TypedDict, total=False): + """Declarative response span data config.""" + + # Attributes always extracted from the response object. + sources: SourceMapping + # Attributes extracted only when PII sending is enabled. + pii_sources: SourceMapping + # Declarative token usage paths. + usage: UsageConfig + + class OperationConfig(TypedDict, total=False): + """Full declarative config for an AI operation (chat, embeddings, etc.).""" + + # Key/value pairs set on every span unconditionally. + static: Dict[str, Any] + # Maps kwarg names to SPANDATA keys (always set if present in kwargs). + params: Dict[str, str] + # Maps kwarg names to SPANDATA keys (only set when PII is enabled). + pii_params: Dict[str, str] + # Non-streaming response config. + response: ResponseConfig + # Streaming response config (different attribute paths). + stream_response: ResponseConfig + # Source paths to extract a full response object from a stream-end event + # (V1 pattern: reuse "response" config after extracting). 
+ stream_response_object: SourcePaths + + +# ── Configs ────────────────────────────────────────────────────────────────── + + +COHERE_EMBED_CONFIG: "OperationConfig" = { + "static": { + SPANDATA.GEN_AI_SYSTEM: "cohere", + SPANDATA.GEN_AI_OPERATION_NAME: "embeddings", + }, + "params": {"model": SPANDATA.GEN_AI_REQUEST_MODEL}, + "response": { + "usage": { + "input_tokens": [("meta", "billed_units", "input_tokens")], + "total_tokens": [("meta", "billed_units", "input_tokens")], + }, + }, +} + + +COHERE_V1_CHAT_CONFIG: "OperationConfig" = { + "static": { + SPANDATA.GEN_AI_SYSTEM: "cohere", + SPANDATA.GEN_AI_OPERATION_NAME: "chat", + }, + "response": { + "sources": { + SPANDATA.GEN_AI_RESPONSE_MODEL: [("model",)], + SPANDATA.GEN_AI_RESPONSE_ID: [("generation_id",)], + SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS: [("finish_reason",)], + }, + "pii_sources": { + SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS: [("tool_calls",)], + }, + "usage": { + "input_tokens": [ + ("meta", "billed_units", "input_tokens"), + ("meta", "tokens", "input_tokens"), + ], + "output_tokens": [ + ("meta", "billed_units", "output_tokens"), + ("meta", "tokens", "output_tokens"), + ], + }, + }, + "stream_response_object": [("response",)], +} + + +STREAM_DELTA_TEXT_SOURCES = [("delta", "message", "content", "text")] + + +COHERE_V2_CHAT_CONFIG: "OperationConfig" = { + "static": { + SPANDATA.GEN_AI_SYSTEM: "cohere", + SPANDATA.GEN_AI_OPERATION_NAME: "chat", + }, + "pii_params": { + "tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, + }, + "response": { + "sources": { + SPANDATA.GEN_AI_RESPONSE_MODEL: [("model",)], + SPANDATA.GEN_AI_RESPONSE_ID: [("id",)], + SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS: [("finish_reason",)], + }, + "pii_sources": { + SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS: [("message", "tool_calls")], + }, + "usage": { + "input_tokens": [ + ("usage", "billed_units", "input_tokens"), + ("usage", "tokens", "input_tokens"), + ], + "output_tokens": [ + ("usage", "billed_units", "output_tokens"), + ("usage", 
"tokens", "output_tokens"), + ], + }, + }, + "stream_response": { + "sources": { + SPANDATA.GEN_AI_RESPONSE_ID: [("id",)], + SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS: [("delta", "finish_reason")], + }, + "usage": { + "input_tokens": [ + ("delta", "usage", "billed_units", "input_tokens"), + ("delta", "usage", "tokens", "input_tokens"), + ], + "output_tokens": [ + ("delta", "usage", "billed_units", "output_tokens"), + ("delta", "usage", "tokens", "output_tokens"), + ], + }, + }, +} diff --git a/sentry_sdk/integrations/cohere/v1.py b/sentry_sdk/integrations/cohere/v1.py new file mode 100644 index 0000000000..193f863a33 --- /dev/null +++ b/sentry_sdk/integrations/cohere/v1.py @@ -0,0 +1,158 @@ +import sys +from functools import wraps + +from sentry_sdk.ai.span_config import ( + set_request_span_data, + set_request_messages, + set_response_span_data, +) +from sentry_sdk.ai.utils import get_first_from_sources, transform_message_content +from sentry_sdk.consts import OP, SPANDATA +from sentry_sdk.integrations.cohere.configs import COHERE_V1_CHAT_CONFIG + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any, Callable, Iterator + from cohere import StreamedChatResponse + +import sentry_sdk +from sentry_sdk.integrations.cohere import ( + CohereIntegration, + _capture_exception, +) +from sentry_sdk.scope import should_send_default_pii +from sentry_sdk.utils import capture_internal_exceptions, reraise + +try: + from cohere import ChatStreamEndEvent, NonStreamedChatResponse + + try: + from cohere import StreamEndStreamedChatResponse + except ImportError: + from cohere import ( + StreamedChatResponse_StreamEnd as StreamEndStreamedChatResponse, + ) + + _has_chat_types = True +except ImportError: + _has_chat_types = False + + +def setup_v1(wrap_embed_fn): + # type: (Callable[..., Any]) -> None + try: + from cohere.base_client import BaseCohere + from cohere.client import Client + except ImportError: + return + + BaseCohere.chat = 
_wrap_chat(BaseCohere.chat, streaming=False) + BaseCohere.chat_stream = _wrap_chat(BaseCohere.chat_stream, streaming=True) + Client.embed = wrap_embed_fn(Client.embed) + + +def _wrap_chat(f, streaming): + # type: (Callable[..., Any], bool) -> Callable[..., Any] + if not _has_chat_types: + return f + + @wraps(f) + def new_chat(*args, **kwargs): + # type: (*Any, **Any) -> Any + integration = sentry_sdk.get_client().get_integration(CohereIntegration) + if ( + integration is None + or "message" not in kwargs + or not isinstance(kwargs.get("message"), str) + ): + return f(*args, **kwargs) + + model = kwargs.get("model", "") + include_pii = should_send_default_pii() and integration.include_prompts + + with sentry_sdk.start_span( + op=OP.GEN_AI_CHAT, + name=f"chat {model}".strip(), + origin=CohereIntegration.origin, + ) as span: + try: + response = f(*args, **kwargs) + except Exception as e: + exc_info = sys.exc_info() + with capture_internal_exceptions(): + _capture_exception(e) + reraise(*exc_info) + + with capture_internal_exceptions(): + span_data = { + SPANDATA.GEN_AI_RESPONSE_STREAMING: streaming, + SPANDATA.GEN_AI_REQUEST_MODEL: model if model else None, + } + set_request_span_data( + span, kwargs, integration, COHERE_V1_CHAT_CONFIG, span_data + ) + if include_pii: + set_request_messages(span, _extract_v1_messages(kwargs)) + + if streaming: + return _iter_stream_events(response, span, include_pii) + response_text = ( + _extract_v1_response_text(response) if include_pii else None + ) + set_response_span_data( + span, + response, + include_pii, + COHERE_V1_CHAT_CONFIG["response"], + response_text, + ) + return response + + return new_chat + + +def _iter_stream_events(old_iterator, span, include_pii): + # type: (Any, Any, bool) -> Iterator[StreamedChatResponse] + for x in old_iterator: + with capture_internal_exceptions(): + if isinstance(x, ChatStreamEndEvent) or isinstance( + x, StreamEndStreamedChatResponse + ): + response = get_first_from_sources( + x, 
COHERE_V1_CHAT_CONFIG["stream_response_object"] + ) + if response is not None: + response_text = ( + _extract_v1_response_text(response) if include_pii else None + ) + set_response_span_data( + span, + response, + include_pii, + COHERE_V1_CHAT_CONFIG["response"], + response_text, + ) + yield x + + +def _extract_v1_messages(kwargs): + # type: (Any) -> list[dict[str, str]] + messages = [] + for x in kwargs.get("chat_history", []): + messages.append( + { + "role": getattr(x, "role", ""), + "content": transform_message_content(getattr(x, "message", "")), + } + ) + message = kwargs.get("message") + if message: + messages.append({"role": "user", "content": transform_message_content(message)}) + return messages + + +def _extract_v1_response_text(response): + # type: (Any) -> list[str] | None + text = getattr(response, "text", None) + return [text] if text is not None else None diff --git a/sentry_sdk/integrations/cohere/v2.py b/sentry_sdk/integrations/cohere/v2.py new file mode 100644 index 0000000000..752551ac73 --- /dev/null +++ b/sentry_sdk/integrations/cohere/v2.py @@ -0,0 +1,172 @@ +import sys +from functools import wraps + +from sentry_sdk.ai.span_config import ( + set_request_span_data, + set_request_messages, + set_response_span_data, +) +from sentry_sdk.ai.utils import ( + get_first_from_sources, + transform_message_content, + transitive_getattr, +) +from sentry_sdk.consts import OP, SPANDATA +from sentry_sdk.integrations.cohere.configs import ( + COHERE_V2_CHAT_CONFIG, + STREAM_DELTA_TEXT_SOURCES, +) + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any, Callable, Iterator + from sentry_sdk.tracing import Span + +import sentry_sdk +from sentry_sdk.integrations.cohere import ( + CohereIntegration, + _capture_exception, +) +from sentry_sdk.scope import should_send_default_pii +from sentry_sdk.utils import capture_internal_exceptions, reraise + +try: + from cohere.v2.client import V2Client as CohereV2Client + + try: + from 
cohere.v2.types import MessageEndV2ChatStreamResponse, V2ChatResponse + + if TYPE_CHECKING: + from cohere.v2.types import V2ChatStreamResponse + except ImportError: + from cohere.types import ChatResponse as V2ChatResponse + from cohere.types import ( + MessageEndStreamedChatResponseV2 as MessageEndV2ChatStreamResponse, + ) + + if TYPE_CHECKING: + from cohere.types import StreamedChatResponseV2 as V2ChatStreamResponse + + _has_v2 = True +except ImportError: + _has_v2 = False + + +def setup_v2(wrap_embed_fn): + # type: (Callable[..., Any]) -> None + if not _has_v2: + return + CohereV2Client.chat = _wrap_chat_v2(CohereV2Client.chat, streaming=False) + CohereV2Client.chat_stream = _wrap_chat_v2( + CohereV2Client.chat_stream, streaming=True + ) + CohereV2Client.embed = wrap_embed_fn(CohereV2Client.embed) + + +def _wrap_chat_v2(f, streaming): + # type: (Callable[..., Any], bool) -> Callable[..., Any] + @wraps(f) + def new_chat(*args, **kwargs): + # type: (*Any, **Any) -> Any + integration = sentry_sdk.get_client().get_integration(CohereIntegration) + if integration is None or "messages" not in kwargs: + return f(*args, **kwargs) + + model = kwargs.get("model", "") + include_pii = should_send_default_pii() and integration.include_prompts + + with sentry_sdk.start_span( + op=OP.GEN_AI_CHAT, + name=f"chat {model}".strip(), + origin=CohereIntegration.origin, + ) as span: + try: + response = f(*args, **kwargs) + except Exception as e: + exc_info = sys.exc_info() + with capture_internal_exceptions(): + _capture_exception(e) + reraise(*exc_info) + + with capture_internal_exceptions(): + span_data = { + SPANDATA.GEN_AI_RESPONSE_STREAMING: streaming, + SPANDATA.GEN_AI_REQUEST_MODEL: model if model else None, + } + set_request_span_data( + span, kwargs, integration, COHERE_V2_CHAT_CONFIG, span_data + ) + if include_pii: + set_request_messages( + span, _extract_v2_messages(kwargs.get("messages", [])) + ) + + if streaming: + return _iter_v2_stream_events(response, span, 
include_pii) + response_text = ( + _extract_v2_response_text(response) if include_pii else None + ) + set_response_span_data( + span, + response, + include_pii, + COHERE_V2_CHAT_CONFIG["response"], + response_text, + ) + return response + + return new_chat + + +def _iter_v2_stream_events(old_iterator, span, include_pii): + # type: (Any, Span, bool) -> Iterator[V2ChatStreamResponse] + collected_text = [] # type: list[str] + for x in old_iterator: + with capture_internal_exceptions(): + _append_stream_delta_text(collected_text, x) + if isinstance(x, MessageEndV2ChatStreamResponse): + response_text = ( + ["".join(collected_text)] + if include_pii and collected_text + else None + ) + set_response_span_data( + span, + x, + include_pii, + COHERE_V2_CHAT_CONFIG["stream_response"], + response_text, + ) + yield x + + +def _append_stream_delta_text(collected_text, event): + # type: (list[str], Any) -> None + if transitive_getattr(event, "type") != "content-delta": + return + content_text = get_first_from_sources(event, STREAM_DELTA_TEXT_SOURCES) + if content_text is not None: + collected_text.append(content_text) + + +def _extract_v2_messages(messages): + # type: (Any) -> list[dict[str, Any]] + result = [] + for msg in messages: + role = msg["role"] if isinstance(msg, dict) else getattr(msg, "role", "unknown") + content = ( + msg["content"] if isinstance(msg, dict) else getattr(msg, "content", "") + ) + result.append({"role": role, "content": transform_message_content(content)}) + return result + + +def _extract_v2_response_text(response): + # type: (Any) -> list[str] | None + content = get_first_from_sources(response, [("message", "content")], True) + if content: + texts = [item.text for item in content if hasattr(item, "text")] + if texts: + return texts + return None diff --git a/tests/integrations/cohere/test_cohere.py b/tests/integrations/cohere/test_cohere.py index 9ff56ed697..d1a0657e54 100644 --- a/tests/integrations/cohere/test_cohere.py +++ 
b/tests/integrations/cohere/test_cohere.py @@ -2,21 +2,32 @@ import httpx import pytest +from unittest import mock + +from httpx import Client as HTTPXClient + from cohere import Client, ChatMessage from sentry_sdk import start_transaction from sentry_sdk.consts import SPANDATA from sentry_sdk.integrations.cohere import CohereIntegration -from unittest import mock # python 3.3 and above -from httpx import Client as HTTPXClient +try: + from cohere import ClientV2 + + has_v2 = True +except ImportError: + has_v2 = False + + +# --- V1 Chat (non-streaming) --- @pytest.mark.parametrize( "send_default_pii, include_prompts", [(True, True), (True, False), (False, True), (False, False)], ) -def test_nonstreaming_chat( +def test_v1_nonstreaming_chat( sentry_init, capture_events, send_default_pii, include_prompts ): sentry_init( @@ -32,6 +43,8 @@ def test_nonstreaming_chat( 200, json={ "text": "the model response", + "generation_id": "gen-123", + "finish_reason": "COMPLETE", "meta": { "billed_units": { "output_tokens": 10, @@ -47,40 +60,41 @@ def test_nonstreaming_chat( model="some-model", chat_history=[ChatMessage(role="SYSTEM", message="some context")], message="hello", - ).text + ) - assert response == "the model response" + assert response.text == "the model response" tx = events[0] assert tx["type"] == "transaction" span = tx["spans"][0] - assert span["op"] == "ai.chat_completions.create.cohere" - assert span["data"][SPANDATA.AI_MODEL_ID] == "some-model" + assert span["op"] == "gen_ai.chat" + assert span["origin"] == "auto.ai.cohere" + assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "cohere" + assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat" + assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "some-model" + assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is False if send_default_pii and include_prompts: - assert ( - '{"role": "system", "content": "some context"}' - in span["data"][SPANDATA.AI_INPUT_MESSAGES] - ) - assert ( - '{"role": "user", 
"content": "hello"}' - in span["data"][SPANDATA.AI_INPUT_MESSAGES] - ) - assert "the model response" in span["data"][SPANDATA.AI_RESPONSES] + assert "hello" in span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES] + assert "the model response" in span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] else: - assert SPANDATA.AI_INPUT_MESSAGES not in span["data"] - assert SPANDATA.AI_RESPONSES not in span["data"] + assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"] + assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"] assert span["data"]["gen_ai.usage.output_tokens"] == 10 assert span["data"]["gen_ai.usage.input_tokens"] == 20 assert span["data"]["gen_ai.usage.total_tokens"] == 30 -# noinspection PyTypeChecker +# --- V1 Chat (streaming) --- + + @pytest.mark.parametrize( "send_default_pii, include_prompts", [(True, True), (True, False), (False, True), (False, False)], ) -def test_streaming_chat(sentry_init, capture_events, send_default_pii, include_prompts): +def test_v1_streaming_chat( + sentry_init, capture_events, send_default_pii, include_prompts +): sentry_init( integrations=[CohereIntegration(include_prompts=include_prompts)], traces_sample_rate=1.0, @@ -102,6 +116,7 @@ def test_streaming_chat(sentry_init, capture_events, send_default_pii, include_p "finish_reason": "COMPLETE", "response": { "text": "the model response", + "generation_id": "gen-123", "meta": { "billed_units": { "output_tokens": 10, @@ -130,29 +145,29 @@ def test_streaming_chat(sentry_init, capture_events, send_default_pii, include_p tx = events[0] assert tx["type"] == "transaction" span = tx["spans"][0] - assert span["op"] == "ai.chat_completions.create.cohere" - assert span["data"][SPANDATA.AI_MODEL_ID] == "some-model" + assert span["op"] == "gen_ai.chat" + assert span["origin"] == "auto.ai.cohere" + assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "cohere" + assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat" + assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "some-model" + assert 
def test_v1_streaming_error_propagates(sentry_init, capture_events):
    """Stream errors must not be silently swallowed by capture_internal_exceptions.

    Drives the V1 stream-event iterator directly with a generator that raises
    mid-stream and asserts the ConnectionError escapes to the caller.
    """
    sentry_init(integrations=[CohereIntegration()], traces_sample_rate=1.0)
    # Install event capture so SDK captures stay isolated to this test; the
    # captured events themselves are not inspected here, so no local binding.
    capture_events()

    from sentry_sdk.integrations.cohere.v1 import _iter_stream_events

    def failing_iterator():
        # Yield one event first so the failure happens mid-stream, not on the
        # very first pull.
        yield "event1"
        raise ConnectionError("stream interrupted")

    with start_transaction(name="cohere tx") as tx:
        span = tx.start_child(op="gen_ai.chat")
        with pytest.raises(ConnectionError, match="stream interrupted"):
            list(_iter_stream_events(failing_iterator(), span, False))
        span.finish()
@pytest.mark.skipif(not has_v2, reason="Cohere V2 client not available")
@pytest.mark.parametrize(
    "send_default_pii, include_prompts",
    [(True, True), (True, False), (False, True), (False, False)],
)
def test_v2_nonstreaming_chat(
    sentry_init, capture_events, send_default_pii, include_prompts
):
    """Non-streaming V2 chat produces a gen_ai.chat span with model, response
    id, token usage, and (when PII is allowed) request/response text.
    """
    sentry_init(
        integrations=[CohereIntegration(include_prompts=include_prompts)],
        traces_sample_rate=1.0,
        send_default_pii=send_default_pii,
    )
    events = capture_events()

    # Canned V2 chat response; billed_units (20/10) are what the SDK records,
    # while raw tokens (25/15) must be ignored.
    mocked_payload = {
        "id": "resp-123",
        "model": "some-model",
        "finish_reason": "COMPLETE",
        "message": {
            "role": "assistant",
            "content": [{"type": "text", "text": "the model response"}],
        },
        "usage": {
            "billed_units": {
                "input_tokens": 20,
                "output_tokens": 10,
            },
            "tokens": {
                "input_tokens": 25,
                "output_tokens": 15,
            },
        },
    }

    client = ClientV2(api_key="z")
    HTTPXClient.request = mock.Mock(
        return_value=httpx.Response(200, json=mocked_payload)
    )

    with start_transaction(name="cohere tx"):
        response = client.chat(
            model="some-model",
            messages=[
                {"role": "system", "content": "some context"},
                {"role": "user", "content": "hello"},
            ],
        )

    assert response.message.content[0].text == "the model response"

    transaction = events[0]
    assert transaction["type"] == "transaction"
    chat_span = transaction["spans"][0]
    span_data = chat_span["data"]

    assert chat_span["op"] == "gen_ai.chat"
    assert chat_span["origin"] == "auto.ai.cohere"
    assert span_data[SPANDATA.GEN_AI_SYSTEM] == "cohere"
    assert span_data[SPANDATA.GEN_AI_OPERATION_NAME] == "chat"
    assert span_data[SPANDATA.GEN_AI_RESPONSE_MODEL] == "some-model"
    assert span_data[SPANDATA.GEN_AI_RESPONSE_STREAMING] is False
    assert span_data[SPANDATA.GEN_AI_RESPONSE_ID] == "resp-123"

    if send_default_pii and include_prompts:
        assert "hello" in span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES]
        assert "the model response" in span_data[SPANDATA.GEN_AI_RESPONSE_TEXT]
    else:
        assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span_data
        assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span_data

    # Usage must come from billed_units, with total derived as input + output.
    for key, expected in (
        ("gen_ai.usage.output_tokens", 10),
        ("gen_ai.usage.input_tokens", 20),
        ("gen_ai.usage.total_tokens", 30),
    ):
        assert span_data[key] == expected
@pytest.mark.skipif(not has_v2, reason="Cohere V2 client not available")
def test_v2_streaming_error_propagates(sentry_init, capture_events):
    """Stream errors must not be silently swallowed by capture_internal_exceptions.

    Drives the V2 stream-event iterator directly with a generator that raises
    mid-stream and asserts the ConnectionError escapes to the caller.
    """
    sentry_init(integrations=[CohereIntegration()], traces_sample_rate=1.0)
    # Install event capture so SDK captures stay isolated to this test; the
    # captured events themselves are not inspected here, so no local binding.
    capture_events()

    from sentry_sdk.integrations.cohere.v2 import _iter_v2_stream_events

    def failing_iterator():
        # Yield one event first so the failure happens mid-stream, not on the
        # very first pull.
        yield "event1"
        raise ConnectionError("stream interrupted")

    with start_transaction(name="cohere tx") as tx:
        span = tx.start_child(op="gen_ai.chat")
        with pytest.raises(ConnectionError, match="stream interrupted"):
            list(_iter_v2_stream_events(failing_iterator(), span, False))
        span.finish()
@pytest.mark.skipif(not has_v2, reason="Cohere V2 client not available")
@pytest.mark.parametrize(
    "send_default_pii, include_prompts",
    [(True, True), (True, False), (False, True), (False, False)],
)
def test_v2_embed(sentry_init, capture_events, send_default_pii, include_prompts):
    """V2 embed produces a gen_ai.embeddings span with token usage and,
    only when PII is allowed, the embedding input text.
    """
    sentry_init(
        integrations=[CohereIntegration(include_prompts=include_prompts)],
        traces_sample_rate=1.0,
        send_default_pii=send_default_pii,
    )
    events = capture_events()

    # Canned V2 embed response: embeddings are keyed by type ("float"),
    # unlike the V1 flat-list shape.
    mocked_payload = {
        "response_type": "embeddings_floats",
        "id": "1",
        "texts": ["hello"],
        "embeddings": {"float": [[1.0, 2.0, 3.0]]},
        "meta": {
            "billed_units": {
                "input_tokens": 10,
            }
        },
    }

    client = ClientV2(api_key="z")
    HTTPXClient.request = mock.Mock(
        return_value=httpx.Response(200, json=mocked_payload)
    )

    with start_transaction(name="cohere tx"):
        client.embed(
            texts=["hello"],
            model="embed-english-v3.0",
            input_type="search_document",
            embedding_types=["float"],
        )

    transaction = events[0]
    assert transaction["type"] == "transaction"
    embed_span = transaction["spans"][0]
    span_data = embed_span["data"]

    assert embed_span["op"] == "gen_ai.embeddings"
    assert embed_span["origin"] == "auto.ai.cohere"
    assert span_data[SPANDATA.GEN_AI_SYSTEM] == "cohere"
    assert span_data[SPANDATA.GEN_AI_OPERATION_NAME] == "embeddings"

    if send_default_pii and include_prompts:
        assert "hello" in span_data[SPANDATA.GEN_AI_EMBEDDINGS_INPUT]
    else:
        assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in span_data

    # Embeddings bill input tokens only, so total equals input.
    assert span_data["gen_ai.usage.input_tokens"] == 10
    assert span_data["gen_ai.usage.total_tokens"] == 10