diff --git a/README.md b/README.md index 03dd017b..929e2807 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,7 @@ Mistral AI API: Our Chat Completion and Embeddings APIs specification. Create yo * [Resource Management](#resource-management) * [Debugging](#debugging) * [IDE Support](#ide-support) + * [Telemetry \& Observability](#telemetry--observability) * [Development](#development) * [Contributions](#contributions) @@ -841,7 +842,7 @@ print(res.choices[0].message.content) operations. These operations will expose the stream as [Generator][generator] that can be consumed using a simple `for` loop. The loop will terminate when the server no longer has any events to send and closes the -underlying connection. +underlying connection. The stream is also a [Context Manager][context-manager] and can be used with the `with` statement and will close the underlying connection when the context is exited. @@ -1282,9 +1283,117 @@ Generally, the SDK will work well with most IDEs out of the box. However, when u + +## Telemetry & Observability + +The SDK can emit [OpenTelemetry](https://opentelemetry.io/) traces for the API calls it makes (chat, agents, embeddings, OCR, …), following the +[GenAI semantic conventions](https://opentelemetry.io/docs/specs/semconv/gen-ai/). +Spans capture the operation, model, token usage, and — unless redacted — the input/output messages and tool calls. Telemetry is **opt-in** and lives in the `mistralai.extra.observability` module. + +### Installation + +Install the `telemetry` extra: + +```bash +pip install "mistralai[telemetry]" +# or: uv add "mistralai[telemetry]" +``` + +### Enabling telemetry + +Either set an environment variable before creating the client: + +```bash +export MISTRAL_SDK_TELEMETRY=dedicated # dedicated | global | false +``` + +or configure it in code: + +```python +import os +from mistralai.client import Mistral +from mistralai.extra.observability import configure_telemetry + +with Mistral(api_key=os.environ["MISTRAL_API_KEY"]) as client: + # Dedicated mode (default): the SDK creates and owns an OTLP exporter that + # ships spans to the Mistral telemetry endpoint. Spans are redacted before + # export. + configure_telemetry(client) + + client.chat.complete( + model="mistral-small-latest", + messages=[{"role": "user", "content": "Hello!"}], + ) +``` + +### Provider modes + +`configure_telemetry(client, provider=...)` selects where spans go and who owns the export pipeline: + +| `provider` | Who owns the exporter | Where spans go | Redaction | +| ---------- | --------------------- | -------------- | --------- | +| `"dedicated"` (default) | The SDK | Mistral telemetry endpoint | Applied automatically | +| `"global"` | Your application | Your global OpenTelemetry provider | **Not** applied — you need to wrap your own exporter | +| a `TracerProvider` | Your application | The provider you pass | **Not** applied — you need to wrap your own exporter | + +In `global`/custom modes your application owns the pipeline, so the `redaction` argument is ignored (a warning is logged). Wrap your own exporter with `RedactingSpanExporter` to redact spans there: + +```python +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from mistralai.extra.observability import RedactingSpanExporter, configure_telemetry + +provider = TracerProvider() +provider.add_span_processor( + BatchSpanProcessor(RedactingSpanExporter(OTLPSpanExporter())) +) +trace.set_tracer_provider(provider) + +# SDK spans now flow through your global provider (already redacted above). +configure_telemetry(client, provider="global") +``` + +### Redaction + +In dedicated mode, redaction is on by default. Control it with the `redaction` argument, which also accepts any of the reusable policies from `mistralai.extra.observability`: + +```python +from mistralai.extra.observability import AttributeRedactionPolicy + +configure_telemetry(client) # default policy (regex) +configure_telemetry(client, redaction=AttributeRedactionPolicy()) # very conservative key-oriented policy +configure_telemetry(client, redaction=False) # disabled - no redaction +configure_telemetry( # custom callback to control how attributes are redacted + client, + redaction=lambda key, value: None if "email" in key else value, +) +``` + +| Policy | Strategy | Trade-off | +| ------ | -------- | --------- | +| `RegexRedactionPolicy` (default, `redaction=True`) | Content-oriented: keeps keys and structure, redacts matched substrings (secret tokens plus PII — emails, card-like sequences, IPv4). | Redacts most sensitive data while preserving observability value; may miss free-form PII or secrets not in the pattern set. | +| `AttributeRedactionPolicy` | Key-oriented: redacts whole values for sensitive keys (explicit set, fragment match, or non-primitive value), then scans kept values for secret token patterns. | Very conservative, but erases most prompt/response content. | +| `CallbackRedactionPolicy` (`redaction=`) | Your `(key, value) -> value \| None` masker per attribute; return `None` to drop the attribute. | Full control; you own the logic. | + +*Note: the `RedactingSpanExporter` primitive is reusable by any OpenTelemetry application, independent of the Mistral client.* + +### Environment variables + +| Variable | Description | Default | +| -------- | ----------- | ------- | +| `MISTRAL_SDK_TELEMETRY` | Auto-enable telemetry: `dedicated`, `global`, or `false`. | unset (disabled) | +| `MISTRAL_OTLP_TRACES_ENDPOINT` | Override the OTLP traces endpoint used in dedicated mode. | `https://api.mistral.ai/telemetry/v1/traces` | +| `MISTRAL_SDK_DEBUG_TRACING` | Set to `true` for verbose tracing logs. | `false` | +| `MISTRAL_API_KEY` | Used as the bearer token for the dedicated-mode exporter. | — | + +Runnable examples live in [`examples/mistral/observability`](/examples/mistral/observability). + + # Development ## Contributions -While we value open-source contributions to this SDK, this library is generated programmatically. Any manual changes added to internal files will be overwritten on the next generation. -We look forward to hearing your feedback. Feel free to open a PR or an issue with a proof of concept and we'll do our best to include it in a future release. +While we value open-source contributions to this SDK, this library is generated programmatically. Any manual changes added to internal files will be overwritten on the next generation. +We look forward to hearing your feedback. Feel free to open a PR or an issue with a proof of concept and we'll do our best to include it in a future release. diff --git a/examples/mistral/observability/dedicated_telemetry.py b/examples/mistral/observability/dedicated_telemetry.py new file mode 100644 index 00000000..fe9107a7 --- /dev/null +++ b/examples/mistral/observability/dedicated_telemetry.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python +"""Dedicated telemetry mode. + +The SDK creates and owns an OTLP exporter that ships spans to the Mistral +telemetry endpoint. Spans are redacted before export. + +Requires the telemetry extra: pip install "mistralai[telemetry]" +""" + +import os + +from mistralai.client import Mistral +from mistralai.extra.observability import configure_telemetry + + +def main() -> None: + api_key = os.environ["MISTRAL_API_KEY"] + + with Mistral(api_key=api_key) as client: + # Dedicated mode is the default; redaction is on by default. + configure_telemetry(client) + + response = client.chat.complete( + model="mistral-small-latest", + messages=[{"role": "user", "content": "What is the best French cheese?"}], + ) + print(response.choices[0].message.content) + + +if __name__ == "__main__": + main() diff --git a/examples/mistral/observability/global_provider_with_redaction.py b/examples/mistral/observability/global_provider_with_redaction.py new file mode 100644 index 00000000..c1fad207 --- /dev/null +++ b/examples/mistral/observability/global_provider_with_redaction.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python +"""Global provider mode with application-owned redaction. + +In global (or custom TracerProvider) mode your application owns the OTEL export +pipeline, so `configure_telemetry`'s redaction argument is ignored. To redact +spans you wrap your own exporter with RedactingSpanExporter. + +Requires the telemetry extra: pip install "mistralai[telemetry]" +""" + +import os + +from opentelemetry import trace +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor + +from mistralai.client import Mistral +from mistralai.extra.observability import RedactingSpanExporter, configure_telemetry + + +def main() -> None: + api_key = os.environ["MISTRAL_API_KEY"] + + # Build your own provider and wrap the exporter with redaction. + provider = TracerProvider() + provider.add_span_processor( + BatchSpanProcessor(RedactingSpanExporter(OTLPSpanExporter())) + ) + trace.set_tracer_provider(provider) + + with Mistral(api_key=api_key) as client: + # SDK spans flow through the global provider configured above. + configure_telemetry(client, provider="global") + + response = client.chat.complete( + model="mistral-small-latest", + messages=[{"role": "user", "content": "Say hello."}], + ) + print(response.choices[0].message.content) + + +if __name__ == "__main__": + main() diff --git a/examples/mistral/observability/redaction_policies.py b/examples/mistral/observability/redaction_policies.py new file mode 100644 index 00000000..7f6fb024 --- /dev/null +++ b/examples/mistral/observability/redaction_policies.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python +"""Choosing a redaction policy in dedicated telemetry mode. + +The `redaction` argument accepts: + - True (default): the regex (content-oriented) policy + - False: redaction disabled + - a RedactionPolicy instance (e.g. AttributeRedactionPolicy) + - a (key, value) -> value | None callback + +Requires the telemetry extra: pip install "mistralai[telemetry]" +""" + +import os + +from mistralai.client import Mistral +from mistralai.extra.observability import AttributeRedactionPolicy, configure_telemetry + + +def main() -> None: + api_key = os.environ["MISTRAL_API_KEY"] + + with Mistral(api_key=api_key) as client: + configure_telemetry(client, redaction=AttributeRedactionPolicy()) + + # Alternatives: + # configure_telemetry(client, redaction=False) # disable entirely + # configure_telemetry( # custom callback + # client, + # redaction=lambda key, value: None if "email" in key else value, + # ) + + response = client.chat.complete( + model="mistral-small-latest", + messages=[{"role": "user", "content": "Say hello."}], + ) + print(response.choices[0].message.content) + + +if __name__ == "__main__": + main() diff --git a/src/mistralai/extra/observability/__init__.py b/src/mistralai/extra/observability/__init__.py index 772d55b7..3524329d 100644 --- a/src/mistralai/extra/observability/__init__.py +++ b/src/mistralai/extra/observability/__init__.py @@ -4,6 +4,16 @@ from opentelemetry import trace as otel_trace from .otel import MISTRAL_SDK_OTEL_TRACER_NAME +from .redaction import ( + AttributeRedactionPolicy, + CallbackRedactionPolicy, + RedactingSpanExporter, + RedactionPolicy, + RegexRedactionPolicy, + default_redaction_policy, + redact_span, + resolve_policy, +) from .telemetry import ( TelemetryConfigurationError, configure_telemetry, @@ -46,9 +56,17 @@ def set_tracer_provider( __all__ = [ + "AttributeRedactionPolicy", + "CallbackRedactionPolicy", + "RedactingSpanExporter", + "RedactionPolicy", + "RegexRedactionPolicy", "TelemetryConfigurationError", "configure_telemetry", + "default_redaction_policy", "get_telemetry_tracer", + "redact_span", + "resolve_policy", "set_tracer_provider", "trace", ] diff --git a/src/mistralai/extra/observability/redaction.py b/src/mistralai/extra/observability/redaction.py new file mode 100644 index 00000000..0b2a7a19 --- /dev/null +++ b/src/mistralai/extra/observability/redaction.py @@ -0,0 +1,517 @@ +"""Client-side redaction of telemetry spans before they are exported. + +This module provides an export-time masking layer for OpenTelemetry spans so +PII/secrets never leave the client. It is the primary, reusable primitive: +any OTEL application can wrap the exporter it owns with RedactingSpanExporter, +and the Mistral SDK installs it automatically in dedicated telemetry mode (see +``configure_telemetry``). + +Requires the telemetry dependency extra to run, not to import. +""" + +from __future__ import annotations + +import re +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Any, Callable, Final, Union, cast + +from opentelemetry.util.types import AttributeValue + +if TYPE_CHECKING: + from opentelemetry.sdk.trace import ReadableSpan + from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult + + # Inherit from the real base only for static analysis: linters verify our + # export/shutdown/force_flush signatures. At runtime the base is object so + # the optional OpenTelemetry SDK is not required to import this module. + _SpanExporterBase = SpanExporter +else: + _SpanExporterBase = object + + +# User-supplied per-attribute masker: given (key, value), return the value +# to keep. Return the value unchanged to keep it, a redacted value to mask it, +# or None to drop the attribute entirely +AttributeMaskCallback = Callable[[str, AttributeValue], AttributeValue | None] +RedactionPolicyLike = Union["RedactionPolicy", AttributeMaskCallback] +DEFAULT_REDACTED_VALUE: Final[str] = "[REDACTED]" + + +class RedactionPolicy(ABC): + """Base class for redaction policies.""" + + @abstractmethod + def redact_attributes( + self, attributes: Mapping[str, AttributeValue] | None + ) -> dict[str, AttributeValue]: + """Return a new attribute mapping with sensitive data removed.""" + raise NotImplementedError + + def redact_span_name(self, name: str) -> str: + """Return the span name to export. Defaults to unchanged.""" + return name + + def redact_status_description(self, description: str | None) -> str | None: + """Return the status description to export. Defaults to unchanged.""" + return description + + +DEFAULT_SENSITIVE_ATTRIBUTE_KEYS: Final[frozenset[str]] = frozenset( + { + "client.address", + "db.query.text", + "db.statement", + "exception.message", + "exception.stacktrace", + "gen_ai.input.messages", + "gen_ai.output.messages", + "gen_ai.tool.definitions", + "gen_ai.tool.call.arguments", + "gen_ai.tool.call.result", + "http.request.body", + "http.request.header.authorization", + "http.request.header.cookie", + "http.response.body", + "http.response.header.set-cookie", + "http.target", + "http.url", + "server.address", + "url.full", + "url.path", + "url.query", + } +) +DEFAULT_SENSITIVE_ATTRIBUTE_FRAGMENTS: Final[frozenset[str]] = frozenset( + { + "api_key", + "argument", + "arguments", + "authorization", + "body", + "completion", + "content", + "cookie", + "input", + "message", + "messages", + "output", + "password", + "payload", + "prompt", + "secret", + "set_cookie", + "token", + } +) +DEFAULT_SAFE_ATTRIBUTE_KEYS: Final[frozenset[str]] = frozenset( + { + "agent.trace.public", + "client.port", + "error.type", + "exception.type", + "gen_ai.agent.name", + "gen_ai.conversation.id", + "gen_ai.operation.name", + "gen_ai.provider.name", + "gen_ai.request.model", + "gen_ai.response.finish_reasons", + "gen_ai.response.id", + "gen_ai.response.model", + "gen_ai.tool.call.id", + "gen_ai.tool.name", + "gen_ai.tool.type", + "http.request.method", + "http.response.status_code", + "network.protocol.name", + "network.protocol.version", + "server.port", + "url.scheme", + } +) +DEFAULT_TOKEN_PATTERNS: Final[tuple[re.Pattern[str], ...]] = ( + re.compile(r"(?i)bearer\s+[a-z0-9._\-]+"), + re.compile(r"\bgh[pousr]_[A-Za-z0-9_]{20,}\b"), + re.compile(r"\bxox[baprs]-[A-Za-z0-9-]{10,}\b"), + re.compile(r"\bsk-[A-Za-z0-9]{20,}\b"), + re.compile(r"\bAKIA[0-9A-Z]{16}\b"), + re.compile(r"\bAIza[0-9A-Za-z_\-]{35}\b"), + re.compile(r"\beyJ[A-Za-z0-9_\-]+\.[A-Za-z0-9_\-]+\.[A-Za-z0-9_\-]+\b"), + re.compile(r"-----BEGIN [A-Z ]*PRIVATE KEY-----"), + re.compile(r"\b[sr]k_(?:live|test)_[0-9A-Za-z]{10,}\b"), + # AI providers + re.compile(r"\bsk-ant-[A-Za-z0-9\-_]{20,}\b"), + re.compile(r"\bsk-proj-[A-Za-z0-9\-_]{20,}\b"), + re.compile(r"\bhf_[A-Za-z0-9]{30,}\b"), + # Dev / infra tokens + re.compile(r"\bgithub_pat_[A-Za-z0-9_]{22,}\b"), + re.compile(r"\bglpat-[A-Za-z0-9\-=_]{20,22}\b"), + re.compile(r"\bshp(?:at|ca|pa|ss)_[a-fA-F0-9]{32}\b"), + re.compile(r"\bsq0(?:atp|csp|idp)-[0-9A-Za-z\-_]{22,43}\b"), + re.compile(r"\bPMAK-[a-zA-Z0-9]{24,59}\b"), + re.compile(r"\bphc_[a-zA-Z0-9_]{43}\b"), + re.compile(r"\brubygems_[a-f0-9]{48}\b"), + re.compile(r"\blin_api_[0-9A-Za-z]{40}\b"), + re.compile(r"pypi-AgEIcHlwaS5vcmc[A-Za-z0-9\-_]{50,}"), + re.compile(r"\bsecret_[A-Za-z0-9]{43}\b"), + re.compile(r"[A-Za-z0-9]{14}\.atlasv1\.[A-Za-z0-9]{60,}"), + re.compile(r"\bSG\.[A-Za-z0-9_\-]{22}\.[A-Za-z0-9_\-]{43}\b"), + re.compile(r"\bpk_(?:live|test)_[0-9a-zA-Z]{24}\b"), + # Webhook URLs (the whole URL is the secret) + re.compile(r"https://hooks\.slack\.com/services/[A-Za-z0-9/+]{40,}"), + re.compile( + r"https://discord(?:app)?\.com/api/webhooks/[0-9]{17,}/[A-Za-z0-9\-_]{60,}" + ), + re.compile(r"https://hooks\.zapier\.com/hooks/catch/[A-Za-z0-9/]{16,}"), +) +_SAFE_KEY_PREFIXES: Final[tuple[str, ...]] = ("gen_ai.usage.",) +_PRIMITIVE_TYPES: Final[tuple[type, ...]] = (str, bool, int, float) + + +class AttributeRedactionPolicy(RedactionPolicy): + """Key-oriented hybrid policy. + + An opt-in, high-recall alternative to the default policy: "safe by default", at the cost + of erasing most prompt/response content. It redacts whole values for keys judged sensitive + (explicit set, fragment match, or non-primitive value), then runs token_patterns over the + values it keeps to redact values. + """ + + def __init__( + self, + *, + sensitive_keys: frozenset[str] = DEFAULT_SENSITIVE_ATTRIBUTE_KEYS, + safe_keys: frozenset[str] = DEFAULT_SAFE_ATTRIBUTE_KEYS, + sensitive_fragments: frozenset[str] = DEFAULT_SENSITIVE_ATTRIBUTE_FRAGMENTS, + token_patterns: Sequence[re.Pattern[str]] = DEFAULT_TOKEN_PATTERNS, + redact_non_primitive: bool = True, + redacted_value: str = DEFAULT_REDACTED_VALUE, + ) -> None: + self._sensitive_keys = sensitive_keys + self._safe_keys = safe_keys + self._sensitive_fragments = sensitive_fragments + self._token_patterns = tuple(token_patterns) + self._redact_non_primitive = redact_non_primitive + self._redacted_value = redacted_value + + def _should_redact(self, key: str, value: object) -> bool: + normalized_key = key.lower() + if normalized_key in self._safe_keys: + return False + if normalized_key.startswith(_SAFE_KEY_PREFIXES): + return False + if normalized_key in self._sensitive_keys: + return True + if self._has_sensitive_fragment(normalized_key): + return True + return self._redact_non_primitive and not isinstance(value, _PRIMITIVE_TYPES) + + def _has_sensitive_fragment(self, normalized_key: str) -> bool: + normalized_words = normalized_key.replace("-", "_").replace(".", "_") + key_fragments = {word for word in normalized_words.split("_") if word} + return any( + fragment in key_fragments or fragment in normalized_words + for fragment in self._sensitive_fragments + ) + + def redact_attributes( + self, attributes: Mapping[str, AttributeValue] | None + ) -> dict[str, AttributeValue]: + redacted: dict[str, AttributeValue] = {} + if attributes is None: + return redacted + + for key, value in attributes.items(): + if self._should_redact(key, value): + redacted[key] = self._redacted_value + continue + redacted[key] = _redact_value( + value, self._token_patterns, self._redacted_value + ) + + return redacted + + def redact_status_description(self, description: str | None) -> str | None: + """Redact error descriptions (they often carry request/response text).""" + if description is None: + return None + return self._redacted_value + + +DEFAULT_PII_SECRET_PATTERNS: Final[tuple[re.Pattern[str], ...]] = ( + *DEFAULT_TOKEN_PATTERNS, + # Email addresses + re.compile(r"\b[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,}\b"), + # Credit-card-like sequences (13-16 digits, optional spaces/dashes) + re.compile(r"\b(?:\d[ -]?){13,16}\b"), + # IPv4 addresses + re.compile( + r"\b(?:(?:25[0-5]|2[0-4]\d|1?\d?\d)\.){3}(?:25[0-5]|2[0-4]\d|1?\d?\d)\b" + ), +) + + +class RegexRedactionPolicy(RedactionPolicy): + """Content-oriented policy based on regexes. + + This is the default policy. Leaves keys and structure intact, scans string values and + redacts matched substrings. Fewer false positives than AttributeRedactionPolicy and aims + to preserve observability value; may miss free-form PII or secrets not in the default + patterns. + """ + + def __init__( + self, + patterns: Sequence[re.Pattern[str]] = DEFAULT_PII_SECRET_PATTERNS, + *, + redacted_value: str = DEFAULT_REDACTED_VALUE, + ) -> None: + self._patterns = tuple(patterns) + self._redacted_value = redacted_value + + def redact_attributes( + self, attributes: Mapping[str, AttributeValue] | None + ) -> dict[str, AttributeValue]: + redacted: dict[str, AttributeValue] = {} + if attributes is None: + return redacted + + for key, value in attributes.items(): + redacted[key] = _redact_value(value, self._patterns, self._redacted_value) + + return redacted + + def redact_span_name(self, name: str) -> str: + return _redact_text(name, self._patterns, self._redacted_value) + + def redact_status_description(self, description: str | None) -> str | None: + if description is None: + return None + return _redact_text(description, self._patterns, self._redacted_value) + + +class CallbackRedactionPolicy(RedactionPolicy): + """Callback-based policy for users to provide custom redaction capabilities. + + The callback is invoked per attribute and should return the value to keep or None to drop the attribute. + Span name and status description are left unchanged (the callback operates on attributes only). + """ + + def __init__(self, mask_function: AttributeMaskCallback) -> None: + self._mask_function = mask_function + + def redact_attributes( + self, attributes: Mapping[str, AttributeValue] | None + ) -> dict[str, AttributeValue]: + redacted: dict[str, AttributeValue] = {} + if attributes is None: + return redacted + + for key, value in attributes.items(): + masked = self._mask_function(key, value) + if masked is None: + continue + redacted[key] = masked + + return redacted + + +# Helpers +def default_redaction_policy() -> RedactionPolicy: + return RegexRedactionPolicy() + + +def resolve_policy(policy: RedactionPolicyLike | None) -> RedactionPolicy: + if policy is None: + return default_redaction_policy() + if isinstance(policy, RedactionPolicy): + return policy + if callable(policy): + return CallbackRedactionPolicy(policy) + raise TypeError( + "redaction policy must be a RedactionPolicy, a callable, or None; " + f"got {type(policy).__name__}." + ) + + +def resolve_redaction(redaction: RedactionPolicyLike | bool) -> RedactionPolicy | None: + """Resolve redaction setting into a policy or None to disable redaction. + + True yields the default policy, False disables redaction entirely, + and a policy or (key, value)->value | None callback is used as-is. + """ + if redaction is False: + return None + if redaction is True: + return default_redaction_policy() + return resolve_policy(redaction) + + +# SpanExporter wrapper +class RedactingSpanExporter(_SpanExporterBase): + """Wrap any SpanExporter to redact spans before delegating export. + + Example + ------- + >>> exporter = RedactingSpanExporter(OTLPSpanExporter(...)) + >>> provider.add_span_processor(BatchSpanProcessor(exporter)) + """ + + def __init__( + self, + exporter: SpanExporter, + policy: RedactionPolicyLike | None = None, + ) -> None: + _load_span_types() # fail fast if the SDK is unavailable + self._exporter = exporter + self._policy = resolve_policy(policy) + + def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: + redacted = [redact_span(span, self._policy) for span in spans] + return self._exporter.export(redacted) + + def shutdown(self) -> None: + self._exporter.shutdown() + + def force_flush(self, timeout_millis: int = 30_000) -> bool: + return self._exporter.force_flush(timeout_millis) + + +def _load_span_types() -> Any: + """Import the OpenTelemetry SDK span classes needed to rebuild spans. + + Raises a helpful error when the optional ``[telemetry]`` extra is missing. + """ + try: + from opentelemetry.sdk.resources import Resource + from opentelemetry.sdk.trace import Event, ReadableSpan + from opentelemetry.trace import Link, SpanKind, Status, StatusCode + except ImportError as exc: # pragma: no cover + raise ImportError( + "Telemetry redaction requires the optional OpenTelemetry SDK " + "dependencies. Install them with `pip install 'mistralai[telemetry]'` " + "or `uv add 'mistralai[telemetry]'`." + ) from exc + + return _SpanTypes( + Event=Event, + Link=Link, + ReadableSpan=ReadableSpan, + Resource=Resource, + SpanKind=SpanKind, + Status=Status, + StatusCode=StatusCode, + ) + + +class _SpanTypes: + __slots__ = ( + "Event", + "Link", + "ReadableSpan", + "Resource", + "SpanKind", + "Status", + "StatusCode", + ) + + def __init__(self, **types: Any) -> None: + for name, value in types.items(): + setattr(self, name, value) + + +def redact_span(span: ReadableSpan, policy: RedactionPolicy) -> ReadableSpan: + types = _load_span_types() + + attributes = policy.redact_attributes(getattr(span, "attributes", None)) + events = _redact_events(getattr(span, "events", ()) or (), policy, types) + links = _redact_links(getattr(span, "links", ()) or (), policy, types) + resource = _redact_resource(getattr(span, "resource", None), policy, types) + status = _redact_status(getattr(span, "status", None), policy, types) + name = policy.redact_span_name(getattr(span, "name", "") or "") + + return types.ReadableSpan( + name=name, + context=getattr(span, "context", None), + parent=getattr(span, "parent", None), + resource=resource, + attributes=attributes, + events=events, + links=links, + kind=getattr(span, "kind", None) or types.SpanKind.INTERNAL, + status=status, + start_time=getattr(span, "start_time", None), + end_time=getattr(span, "end_time", None), + instrumentation_scope=getattr(span, "instrumentation_scope", None), + ) + + +def _redact_events( + events: Sequence[Any], policy: RedactionPolicy, types: Any +) -> list[Any]: + return [ + types.Event( + name=getattr(event, "name", ""), + attributes=policy.redact_attributes(getattr(event, "attributes", None)), + timestamp=getattr(event, "timestamp", None), + ) + for event in events + ] + + +def _redact_links( + links: Sequence[Any], policy: RedactionPolicy, types: Any +) -> list[Any]: + return [ + types.Link( + context=getattr(link, "context", None), + attributes=policy.redact_attributes(getattr(link, "attributes", None)), + ) + for link in links + ] + + +def _redact_resource(resource: Any, policy: RedactionPolicy, types: Any) -> Any: + if resource is None: + return None + return types.Resource( + attributes=policy.redact_attributes(getattr(resource, "attributes", None)), + schema_url=getattr(resource, "schema_url", ""), + ) + + +def _redact_status(status: Any, policy: RedactionPolicy, types: Any) -> Any: + if status is None: + return types.Status() + status_code = getattr(status, "status_code", None) or types.StatusCode.UNSET + description = policy.redact_status_description(getattr(status, "description", None)) + return types.Status(status_code, description) + + +def _redact_value( + value: AttributeValue, + patterns: Sequence[re.Pattern[str]], + redacted_value: str = DEFAULT_REDACTED_VALUE, +) -> AttributeValue: + if isinstance(value, str): + return _redact_text(value, patterns, redacted_value) + if isinstance(value, (list, tuple)): + items = [ + _redact_text(item, patterns, redacted_value) + if isinstance(item, str) + else item + for item in value + ] + return cast(AttributeValue, tuple(items) if isinstance(value, tuple) else items) + return value + + +def _redact_text( + text: str, + patterns: Sequence[re.Pattern[str]], + redacted_value: str = DEFAULT_REDACTED_VALUE, +) -> str: + redacted = text + for pattern in patterns: + redacted = pattern.sub(redacted_value, redacted) + return redacted diff --git a/src/mistralai/extra/observability/telemetry.py b/src/mistralai/extra/observability/telemetry.py index 6845d2cc..b6c77387 100644 --- a/src/mistralai/extra/observability/telemetry.py +++ b/src/mistralai/extra/observability/telemetry.py @@ -12,13 +12,18 @@ from mistralai.client.utils import get_security_from_env from .otel import MISTRAL_SDK_OTEL_TRACER_NAME, OTEL_SERVICE_NAME +from .redaction import ( + RedactingSpanExporter, + RedactionPolicyLike, + resolve_redaction, +) if TYPE_CHECKING: from opentelemetry.sdk.trace import TracerProvider as SDKTracerProvider + from mistralai.client._hooks.tracing import TracingHook from mistralai.client.sdk import Mistral from mistralai.client.sdkconfiguration import SDKConfiguration - from mistralai.client._hooks.tracing import TracingHook MISTRAL_SDK_TELEMETRY_ENV = "MISTRAL_SDK_TELEMETRY" @@ -86,9 +91,26 @@ def _resolve_mistral_telemetry_env() -> TelemetryProviderMode | None: ) from exc +def _warn_redaction_ignored( + redaction: RedactionPolicyLike | bool, + mode: str, +) -> None: + """Warn when redaction may not happen when user might expect it.""" + if redaction is False: # Explicitly turned off + return + logger.warning( + "Telemetry redaction is only applied in 'dedicated' provider mode, where " + "the Mistral SDK owns the exporter. In %r mode the application owns the " + "export pipeline; wrap your exporter with RedactingSpanExporter to redact " + "spans. Ignoring the redaction argument.", + mode, + ) + + def configure_telemetry( client: "Mistral", provider: str | otel_trace.TracerProvider = TELEMETRY_PROVIDER_DEDICATED, + redaction: RedactionPolicyLike | bool = True, ) -> bool: """Configure telemetry provider mode for a Mistral client. @@ -96,6 +118,20 @@ def configure_telemetry( provider="global" clears the per-client provider so SDK spans use the global OpenTelemetry provider. Passing a TracerProvider attaches it to this client without taking ownership of its lifecycle. + + In dedicated mode, spans are redacted before export (safe by default). + You can control this with the redaction argument: + - True: (default) uses the default policy. It scans string values and redacts matched + secrets/PII substrings while preserving keys and surrounding content + - False: disables redaction + - Other RedactionPolicy classes (e.g. the conservative but destructive + AttributeRedactionPolicy) can be found in the redaction module and provided here + - You can also provide a (key, value) -> value | None callback to customize how attributes + get redacted. Your function should return the modified attribute value or None to drop + the attribute. + Note that redaction has no effect when using the global provider mode or providing your own + telemetry provider, since your application controls the provider then. In that case, wrap + your exporter with redaction.RedactingSpanExporter to redact span before export. """ hooks = getattr(client.sdk_configuration, "_hooks", None) if hooks is None: @@ -105,6 +141,7 @@ def configure_telemetry( if isinstance(provider, str): provider_mode = _resolve_provider_mode(provider) if provider_mode == TELEMETRY_PROVIDER_GLOBAL: + _warn_redaction_ignored(redaction, provider_mode) return _use_global_tracer_provider(hook, replace_existing=True) return configure_telemetry_for_hook( @@ -113,6 +150,7 @@ def configure_telemetry( telemetry=provider_mode, finalizer_owner=client, replace_existing=True, + redaction=redaction, ) if isinstance(provider, bool): @@ -121,6 +159,7 @@ def configure_telemetry( "or an OpenTelemetry TracerProvider." ) + _warn_redaction_ignored(redaction, "custom") _attach_custom_tracer_provider(hook, provider) return True @@ -170,8 +209,14 @@ def configure_telemetry_for_hook( finalizer_owner: Any | None = None, respect_global_provider: bool = False, replace_existing: bool = False, + redaction: RedactionPolicyLike | bool = True, ) -> bool: - """Configure telemetry for a tracing hook when the user has opted in.""" + """Configure telemetry for a tracing hook when the user has opted in. + + In dedicated mode the SDK-owned OTLP exporter is wrapped with a + RedactingSpanExporter unless redaction is False (safe by + default). See configure_telemetry for the accepted values. + """ # Fast path: already resolved and no explicit override requested. if telemetry is None and ( hook._auto_telemetry_provider is not None or hook._telemetry_use_global_provider @@ -232,6 +277,7 @@ def configure_telemetry_for_hook( api_key = _resolve_api_key_from_security(getattr(sdk_config, "security", None)) provider = _create_telemetry_tracer_provider( api_key=api_key, + redaction=redaction, ) _attach_telemetry_provider(hook, provider, finalizer_owner or sdk_config) return True @@ -272,6 +318,7 @@ def _resolve_api_key_from_security(security: Any) -> str: def _create_telemetry_tracer_provider( *, api_key: str | None, + redaction: RedactionPolicyLike | bool = True, ) -> "SDKTracerProvider": ( batch_span_processor_cls, @@ -289,6 +336,9 @@ def _create_telemetry_tracer_provider( endpoint=_resolve_mistral_telemetry_endpoint(), headers={"Authorization": _as_bearer_token(api_key)}, ) + policy = resolve_redaction(redaction) + if policy is not None: + exporter = RedactingSpanExporter(exporter, policy) provider = tracer_provider_cls( resource=resource_cls.create({"service.name": OTEL_SERVICE_NAME}) ) diff --git a/src/mistralai/extra/tests/test_redaction.py b/src/mistralai/extra/tests/test_redaction.py new file mode 100644 index 00000000..16e97e65 --- /dev/null +++ b/src/mistralai/extra/tests/test_redaction.py @@ -0,0 +1,310 @@ +import pytest +from opentelemetry.sdk.trace import ReadableSpan, TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.trace import SpanKind, Status, StatusCode + +from mistralai.extra.observability.redaction import ( + DEFAULT_REDACTED_VALUE, + AttributeRedactionPolicy, + CallbackRedactionPolicy, + RedactingSpanExporter, + RegexRedactionPolicy, + default_redaction_policy, + redact_span, + resolve_policy, + resolve_redaction, +) + + +@pytest.fixture +def attribute_policy() -> AttributeRedactionPolicy: + return AttributeRedactionPolicy() + + +@pytest.fixture +def regex_policy() -> RegexRedactionPolicy: + return RegexRedactionPolicy() + + +class TestAttributeRedactionPolicy: + def test_sensitive_key_redacted_wholesale( + self, attribute_policy: AttributeRedactionPolicy + ): + out = attribute_policy.redact_attributes({"gen_ai.input.messages": "hello"}) + assert out["gen_ai.input.messages"] == DEFAULT_REDACTED_VALUE + + def test_safe_key_kept(self, attribute_policy: AttributeRedactionPolicy): + out = attribute_policy.redact_attributes( + {"gen_ai.request.model": "mistral-large"} + ) + assert out["gen_ai.request.model"] == "mistral-large" + + def test_usage_prefix_kept(self, attribute_policy: AttributeRedactionPolicy): + out = attribute_policy.redact_attributes({"gen_ai.usage.input_tokens": 42}) + assert out["gen_ai.usage.input_tokens"] == 42 + + def test_fragment_match_redacted(self, attribute_policy: AttributeRedactionPolicy): + out = attribute_policy.redact_attributes( + {"custom.prompt.text": "secret prompt"} + ) + assert out["custom.prompt.text"] == DEFAULT_REDACTED_VALUE + + def test_token_pattern_on_kept_string( + self, attribute_policy: AttributeRedactionPolicy + ): + out = attribute_policy.redact_attributes( + {"note": "call token ghp_abcdefghijklmnopqrstuvwxyz0123 now"} + ) + assert out["note"] == "call token [REDACTED] now" + + def test_non_primitive_redacted(self, attribute_policy: AttributeRedactionPolicy): + out = attribute_policy.redact_attributes({"data": ("a", "b")}) + assert out["data"] == DEFAULT_REDACTED_VALUE + + def test_non_primitive_kept_when_disabled(self): + policy = AttributeRedactionPolicy(redact_non_primitive=False) + out = policy.redact_attributes({"safeish.list": ("a", "b")}) + assert out["safeish.list"] == ("a", "b") + + def test_string_sequence_scanned_element_wise_when_kept(self): + policy = AttributeRedactionPolicy(redact_non_primitive=False) + out = policy.redact_attributes( + {"tags": ["plain", "ghp_abcdefghijklmnopqrstuvwxyz0123"]} + ) + assert out["tags"] == ["plain", DEFAULT_REDACTED_VALUE] + + def test_safe_key_string_sequence_scanned( + self, attribute_policy: AttributeRedactionPolicy + ): + out = attribute_policy.redact_attributes( + {"gen_ai.response.finish_reasons": ("stop", "Bearer abc.def")} + ) + assert out["gen_ai.response.finish_reasons"] == ("stop", DEFAULT_REDACTED_VALUE) + + def test_none_attributes_returns_empty( + self, attribute_policy: AttributeRedactionPolicy + ): + assert attribute_policy.redact_attributes(None) == {} + + def test_status_description_redacted( + self, attribute_policy: AttributeRedactionPolicy + ): + assert ( + attribute_policy.redact_status_description("boom: user@x.com") + == DEFAULT_REDACTED_VALUE + ) + assert attribute_policy.redact_status_description(None) is None + + def test_span_name_unchanged(self, attribute_policy: AttributeRedactionPolicy): + assert ( + attribute_policy.redact_span_name("chat mistral-large") + == "chat mistral-large" + ) + + def test_custom_redacted_value(self): + policy = AttributeRedactionPolicy(redacted_value="XXX") + out = policy.redact_attributes({"http.url": "https://x"}) + assert out["http.url"] == "XXX" + + +class TestRegexRedactionPolicy: + def test_email_redacted_inline_preserving_structure( + self, regex_policy: RegexRedactionPolicy + ): + out = regex_policy.redact_attributes( + {"gen_ai.input.messages": '{"content":"reach me at a@b.com"}'} + ) + assert out["gen_ai.input.messages"] == '{"content":"reach me at [REDACTED]"}' + + def test_token_redacted(self, regex_policy: RegexRedactionPolicy): + out = regex_policy.redact_attributes({"h": "Bearer abc.def-ghi"}) + assert out["h"] == "[REDACTED]" + + def test_non_matching_string_kept(self, regex_policy: RegexRedactionPolicy): + out = regex_policy.redact_attributes({"server.address": "prod-host-1"}) + assert out["server.address"] == "prod-host-1" + + def test_non_string_untouched(self, regex_policy: RegexRedactionPolicy): + out = regex_policy.redact_attributes({"n": 5, "b": True}) + assert out == {"n": 5, "b": True} + + def test_span_name_scanned(self, regex_policy: RegexRedactionPolicy): + assert regex_policy.redact_span_name("op a@b.com") == "op [REDACTED]" + + def test_status_description_scanned(self, regex_policy: RegexRedactionPolicy): + assert ( + regex_policy.redact_status_description("failed for a@b.com") + == "failed for [REDACTED]" + ) + + @pytest.mark.parametrize( + "secret", + [ + "AKIAIOSFODNN7EXAMPLE", + "AIzaabcdefghijklmnopqrstuvwxyz012345678", + "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.abc123", + "-----BEGIN RSA PRIVATE KEY-----", + "sk_live_0123456789abcdefghij", + "sk-ant-api03-abcdefghijklmnopqrstuvwxyz012345", + "hf_abcdefghijklmnopqrstuvwxyz0123456789", + "github_pat_abcdefghijklmnopqrstuvwxyz", + "glpat-abcdefghij0123456789ab", + "shpat_0123456789abcdef0123456789abcdef", + "sq0atp-0123456789abcdefghijkl", + "PMAK-0123456789abcdefghijklmn", + "phc_abcdefghijklmnopqrstuvwxyz0123456789abcdefg", + "SG.abcdefghijklmnopqrstuv.abcdefghijklmnopqrstuvwxyz0123456789abcdefg", + "pk_live_0123456789abcdefghijklmn", + "https://hooks.slack.com/services/T00000000/B00000000/abcdefghijklmnopqrstuvwx", + ], + ) + def test_secret_patterns_redacted( + self, regex_policy: RegexRedactionPolicy, secret: str + ): + out = regex_policy.redact_attributes({"v": f"leak {secret} here"}) + value = out["v"] + assert isinstance(value, str) + assert secret not in value + assert DEFAULT_REDACTED_VALUE in value + + def test_string_sequence_scanned_preserving_container( + self, regex_policy: RegexRedactionPolicy + ): + out = regex_policy.redact_attributes({"msgs": ["hello", "reach me at a@b.com"]}) + assert out["msgs"] == ["hello", "reach me at [REDACTED]"] + + def test_tuple_sequence_stays_tuple(self, regex_policy: RegexRedactionPolicy): + out = regex_policy.redact_attributes({"msgs": ("hi", "a@b.com")}) + assert out["msgs"] == ("hi", "[REDACTED]") + + def test_numeric_sequence_untouched(self, regex_policy: RegexRedactionPolicy): + out = regex_policy.redact_attributes({"nums": [1, 2, 3]}) + assert out["nums"] == [1, 2, 3] + + +class TestCallbackRedactionPolicy: + def test_mask_applied_per_attribute(self): + policy = CallbackRedactionPolicy( + lambda key, value: "[X]" if "message" in key else value + ) + out = policy.redact_attributes( + {"gen_ai.output.messages": "hi", "gen_ai.request.model": "m"} + ) + assert out == {"gen_ai.output.messages": "[X]", "gen_ai.request.model": "m"} + + def test_returning_none_drops_attribute(self): + policy = CallbackRedactionPolicy( + lambda key, value: None if key == "drop" else value + ) + out = policy.redact_attributes({"drop": "x", "keep": "y"}) + assert out == {"keep": "y"} + + +class TestResolvePolicy: + def test_none_returns_default(self): + assert isinstance(default_redaction_policy(), RegexRedactionPolicy) + assert isinstance(resolve_policy(None), RegexRedactionPolicy) + + def test_policy_passthrough(self): + policy = RegexRedactionPolicy() + assert resolve_policy(policy) is policy + + def test_callable_wrapped(self): + resolved = resolve_policy(lambda k, v: v) + assert isinstance(resolved, CallbackRedactionPolicy) + + def test_invalid_raises_type_error(self): + with pytest.raises(TypeError): + resolve_policy(123) # type: ignore[arg-type] + + +class TestResolveRedaction: + def test_true_returns_default_policy(self): + assert isinstance(resolve_redaction(True), RegexRedactionPolicy) + + def test_false_returns_none(self): + assert resolve_redaction(False) is None + + def test_policy_passthrough(self): + policy = RegexRedactionPolicy() + assert resolve_redaction(policy) is policy + + def test_callable_wrapped(self): + resolved = resolve_redaction(lambda k, v: v) + assert isinstance(resolved, CallbackRedactionPolicy) + + +class TestRedactSpan: + @staticmethod + def _make_span(): + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + with tracer.start_as_current_span("parent", kind=SpanKind.CLIENT) as span: + span.set_attribute("gen_ai.input.messages", "secret") + span.set_attribute("gen_ai.request.model", "mistral-large") + span.add_event("exception", {"exception.message": "boom"}) + span.set_status(Status(StatusCode.ERROR, "boom detail")) + provider.force_flush() + return exporter.get_finished_spans()[0] + + def test_attributes_redacted(self): + redacted = redact_span(self._make_span(), AttributeRedactionPolicy()) + assert isinstance(redacted, ReadableSpan) + attrs = redacted.attributes + assert attrs is not None + assert attrs["gen_ai.input.messages"] == DEFAULT_REDACTED_VALUE + assert attrs["gen_ai.request.model"] == "mistral-large" + + def test_event_attributes_redacted(self): + redacted = redact_span(self._make_span(), AttributeRedactionPolicy()) + event = redacted.events[0] + assert event.name == "exception" + attrs = event.attributes + assert attrs is not None + assert attrs["exception.message"] == DEFAULT_REDACTED_VALUE + + def test_status_description_redacted(self): + redacted = redact_span(self._make_span(), AttributeRedactionPolicy()) + assert redacted.status.status_code == StatusCode.ERROR + assert redacted.status.description == DEFAULT_REDACTED_VALUE + + def test_identity_preserved(self): + original = self._make_span() + redacted = redact_span(original, AttributeRedactionPolicy()) + assert redacted.context is not None + assert original.context is not None + assert redacted.context.span_id == original.context.span_id + assert redacted.context.trace_id == original.context.trace_id + + +class TestRedactingSpanExporter: + @staticmethod + def _export_through(policy=None): + wrapped = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor( + SimpleSpanProcessor(RedactingSpanExporter(wrapped, policy)) + ) + tracer = provider.get_tracer("test") + with tracer.start_as_current_span("chat") as span: + span.set_attribute("gen_ai.output.messages", "leak Bearer abc.def-ghi") + span.set_attribute("gen_ai.request.model", "mistral-large") + provider.force_flush() + return wrapped.get_finished_spans() + + def test_wrapped_exporter_receives_redacted_spans(self): + spans = self._export_through() + assert len(spans) == 1 + attrs = spans[0].attributes + assert attrs is not None + assert attrs["gen_ai.output.messages"] == f"leak {DEFAULT_REDACTED_VALUE}" + assert attrs["gen_ai.request.model"] == "mistral-large" + + def test_custom_policy_used(self): + spans = self._export_through(AttributeRedactionPolicy()) + attrs = spans[0].attributes + assert attrs is not None + assert attrs["gen_ai.output.messages"] == DEFAULT_REDACTED_VALUE diff --git a/src/mistralai/extra/tests/test_telemetry.py b/src/mistralai/extra/tests/test_telemetry.py index feae82dd..5fc48622 100644 --- a/src/mistralai/extra/tests/test_telemetry.py +++ b/src/mistralai/extra/tests/test_telemetry.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, cast from unittest.mock import MagicMock, patch +import pytest from opentelemetry.sdk.trace import TracerProvider from mistralai.client._hooks import SDKHooks @@ -16,10 +17,16 @@ set_tracer_provider, ) from mistralai.extra.observability.otel import MISTRAL_SDK_OTEL_TRACER_NAME +from mistralai.extra.observability.redaction import ( + AttributeRedactionPolicy, + CallbackRedactionPolicy, + RedactingSpanExporter, + RegexRedactionPolicy, +) from mistralai.extra.observability.telemetry import ( - MISTRAL_TELEMETRY_ENDPOINT, - MISTRAL_SDK_TELEMETRY_ENV, MISTRAL_OTLP_TRACES_ENDPOINT_ENV, + MISTRAL_SDK_TELEMETRY_ENV, + MISTRAL_TELEMETRY_ENDPOINT, TelemetryConfigurationError, _create_telemetry_tracer_provider, configure_telemetry_for_hook, @@ -49,7 +56,9 @@ def _make_client(api_key: str | None = "test-key") -> "Mistral": def _get_tracing_hook(client: "Mistral") -> TracingHook: hooks = client.sdk_configuration.__dict__["_hooks"] - tracing_hooks = [h for h in hooks.before_request_hooks if isinstance(h, TracingHook)] + tracing_hooks = [ + h for h in hooks.before_request_hooks if isinstance(h, TracingHook) + ] assert len(tracing_hooks) == 1 return tracing_hooks[0] @@ -144,6 +153,7 @@ def test_env_dedicated_values_attach_provider(self): self.assertTrue(configured) create_provider.assert_called_once_with( api_key="test-key", + redaction=True, ) self.assertIs(_get_tracing_hook(client).tracer_provider, provider) @@ -207,6 +217,7 @@ def test_configure_telemetry_attaches_per_client_provider(self): self.assertTrue(configured) create_provider.assert_called_once_with( api_key="test-key", + redaction=True, ) self.assertIs(_get_tracing_hook(client).tracer_provider, provider) @@ -224,6 +235,7 @@ def test_configure_telemetry_accepts_explicit_dedicated_provider_mode(self): self.assertTrue(configured) create_provider.assert_called_once_with( api_key="test-key", + redaction=True, ) self.assertIs(_get_tracing_hook(client).tracer_provider, provider) @@ -431,6 +443,7 @@ def test_env_dedicated_uses_mistral_api_key_fallback(self): self.assertTrue(configured) create_provider.assert_called_once_with( api_key="env-key", + redaction=True, ) self.assertIs(_get_tracing_hook(client).tracer_provider, provider) @@ -470,6 +483,7 @@ def test_env_dedicated_ignores_standard_otel_endpoint_env(self): self.assertTrue(configured) create_provider.assert_called_once_with( api_key="test-key", + redaction=True, ) def test_sdk_config_global_uses_global_provider_mode(self): @@ -481,7 +495,9 @@ def test_sdk_config_global_uses_global_provider_mode(self): with patch( "mistralai.extra.observability.telemetry._create_telemetry_tracer_provider" ) as create_provider: - configured = configure_telemetry_for_hook(hook, client.sdk_configuration) + configured = configure_telemetry_for_hook( + hook, client.sdk_configuration + ) self.assertTrue(configured) create_provider.assert_not_called() @@ -531,7 +547,9 @@ def test_configure_telemetry_for_hook_reads_sdk_config_telemetry_flag(self): "mistralai.extra.observability.telemetry._create_telemetry_tracer_provider", return_value=provider, ): - configured = configure_telemetry_for_hook(hook, client.sdk_configuration) + configured = configure_telemetry_for_hook( + hook, client.sdk_configuration + ) self.assertTrue(configured) self.assertIs(hook.tracer_provider, provider) @@ -642,5 +660,115 @@ def test_mistral_endpoint_env_overrides_default_endpoint(self): ) +_TELEMETRY_LOGGER = "mistralai.extra.observability.telemetry" + + +@pytest.fixture +def clear_exporters(): + FakeExporter.instances.clear() + + +class TestTelemetryRedaction: + @staticmethod + def _make_provider(**kwargs): + with patch( + "mistralai.extra.observability.telemetry._load_otel_sdk", + return_value=( + FakeSpanProcessor, + FakeExporter, + FakeResource, + FakeTracerProvider, + ), + ): + return _create_telemetry_tracer_provider(api_key="test-key", **kwargs) + + @staticmethod + def _exporter_of(provider): + assert len(provider.span_processors) == 1 + return provider.span_processors[0].exporter + + def test_dedicated_wraps_exporter_by_default(self, clear_exporters): + provider = self._make_provider() + exporter = self._exporter_of(provider) + assert isinstance(exporter, RedactingSpanExporter) + assert exporter._exporter is FakeExporter.instances[0] + assert isinstance(exporter._policy, RegexRedactionPolicy) + + def test_redaction_true_wraps_with_default_policy(self, clear_exporters): + provider = self._make_provider(redaction=True) + exporter = self._exporter_of(provider) + assert isinstance(exporter, RedactingSpanExporter) + assert isinstance(exporter._policy, RegexRedactionPolicy) + + def test_redaction_false_leaves_exporter_unwrapped(self, clear_exporters): + provider = self._make_provider(redaction=False) + exporter = self._exporter_of(provider) + assert not isinstance(exporter, RedactingSpanExporter) + assert exporter is FakeExporter.instances[0] + + def test_custom_policy_instance_is_used(self, clear_exporters): + policy = AttributeRedactionPolicy() + provider = self._make_provider(redaction=policy) + exporter = self._exporter_of(provider) + assert isinstance(exporter, RedactingSpanExporter) + assert exporter._policy is policy + + def test_callback_is_wrapped_in_callback_policy(self, clear_exporters): + def mask(key, value): + return value + + provider = self._make_provider(redaction=mask) + exporter = self._exporter_of(provider) + assert isinstance(exporter, RedactingSpanExporter) + assert isinstance(exporter._policy, CallbackRedactionPolicy) + + def test_dedicated_mode_forwards_custom_redaction_to_provider(self): + # Wiring test + with patch( + "mistralai.extra.observability.telemetry._create_telemetry_tracer_provider" + ) as create_provider: + create_provider.return_value = FakeProvider() + client = _make_client(api_key="test-key") + policy = RegexRedactionPolicy() + configure_telemetry(client, provider="dedicated", redaction=policy) + + create_provider.assert_called_once_with(api_key="test-key", redaction=policy) + + def test_dedicated_mode_forwards_default_redaction_to_provider(self): + # Wiring test + with patch( + "mistralai.extra.observability.telemetry._create_telemetry_tracer_provider" + ) as create_provider: + create_provider.return_value = FakeProvider() + client = _make_client(api_key="test-key") + configure_telemetry(client, provider="dedicated") + + create_provider.assert_called_once_with(api_key="test-key", redaction=True) + + def test_global_mode_warns_by_default(self, caplog): + client = _make_client(api_key="test-key") + with caplog.at_level("WARNING", logger=_TELEMETRY_LOGGER): + configure_telemetry(client, provider="global") + assert "Telemetry redaction is only applied in 'dedicated'" in caplog.text + + def test_global_mode_does_not_warn_when_redaction_disabled(self, caplog): + client = _make_client(api_key="test-key") + with caplog.at_level("WARNING", logger=_TELEMETRY_LOGGER): + configure_telemetry(client, provider="global", redaction=False) + assert not [r for r in caplog.records if r.name == _TELEMETRY_LOGGER] + + def test_custom_provider_warns_by_default(self, caplog): + client = _make_client(api_key="test-key") + with caplog.at_level("WARNING", logger=_TELEMETRY_LOGGER): + configure_telemetry(client, provider=TracerProvider()) + assert "Telemetry redaction is only applied in 'dedicated'" in caplog.text + + def test_custom_provider_does_not_warn_when_redaction_disabled(self, caplog): + client = _make_client(api_key="test-key") + with caplog.at_level("WARNING", logger=_TELEMETRY_LOGGER): + configure_telemetry(client, provider=TracerProvider(), redaction=False) + assert not [r for r in caplog.records if r.name == _TELEMETRY_LOGGER] + + if __name__ == "__main__": unittest.main()