Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/mistralai/client/_hooks/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Optional, Tuple, Union

import httpx
from opentelemetry import trace
from opentelemetry.trace import Span

from mistralai.extra.observability.otel import (
Expand All @@ -27,14 +28,17 @@

class TracingHook(BeforeRequestHook, AfterSuccessHook, AfterErrorHook):
def __init__(self) -> None:
self.tracer_provider: Optional[trace.TracerProvider] = None
self.tracing_enabled, self.tracer = get_or_create_otel_tracer()

def before_request(
self, hook_ctx: BeforeRequestContext, request: httpx.Request
) -> Union[httpx.Request, Exception]:
# Refresh tracer/provider per request so tracing can be enabled if the
# application configures OpenTelemetry after the client is instantiated.
self.tracing_enabled, self.tracer = get_or_create_otel_tracer()
self.tracing_enabled, self.tracer = get_or_create_otel_tracer(
provider=self.tracer_provider,
)
request, span = get_traced_request_and_span(
tracing_enabled=self.tracing_enabled,
tracer=self.tracer,
Expand Down
42 changes: 41 additions & 1 deletion src/mistralai/extra/observability/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from contextlib import contextmanager
from typing import TYPE_CHECKING

from opentelemetry import trace as otel_trace

from .otel import MISTRAL_SDK_OTEL_TRACER_NAME

if TYPE_CHECKING:
from mistralai.client.sdk import Mistral


@contextmanager
def trace(name: str, **kwargs):
Expand All @@ -12,4 +16,40 @@ def trace(name: str, **kwargs):
yield span


__all__ = ["trace"]
def set_tracer_provider(
client: "Mistral",
provider: otel_trace.TracerProvider,
) -> None:
"""Attach a per-instance OpenTelemetry TracerProvider to a Mistral client.

When set, all SDK spans produced by *client* will be emitted through
*provider* instead of the global TracerProvider.

Usage::

from opentelemetry.sdk.trace import TracerProvider
from mistralai.client import Mistral
from mistralai.extra.observability import set_tracer_provider

client = Mistral(api_key="...")
set_tracer_provider(client, TracerProvider())
"""
from mistralai.client._hooks.tracing import TracingHook

hooks = getattr(client.sdk_configuration, "_hooks", None)
if hooks is None:
raise ValueError(
"Cannot set tracer_provider: SDK hooks not initialised on this client."
)

for hook in hooks.before_request_hooks:
if isinstance(hook, TracingHook):
hook.tracer_provider = provider
return

raise ValueError(
"Cannot set tracer_provider: TracingHook not found in the client's hooks."
)


__all__ = ["trace", "set_tracer_provider"]
25 changes: 15 additions & 10 deletions src/mistralai/extra/observability/otel.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,25 +417,30 @@ def _enrich_span_from_response(
_enrich_ocr(span, response_data)


def get_or_create_otel_tracer() -> tuple[bool, Tracer]:
def get_or_create_otel_tracer(
provider: trace.TracerProvider | None = None,
) -> tuple[bool, Tracer]:
"""
Get a tracer from the current TracerProvider.
Get a tracer from the given or global TracerProvider.

The SDK does not set up its own TracerProvider - it relies on the application
to configure OpenTelemetry. This follows OTEL best practices where:
- Libraries/SDKs get tracers from the global provider
- Applications configure the TracerProvider
When *provider* is supplied (per-instance tracer provider), the tracer is
obtained from it directly. Otherwise the global provider is used, following
the standard OTEL library convention.

If no TracerProvider is configured, the ProxyTracerProvider (default) will
return a NoOp tracer, effectively disabling tracing. Once the application
sets up a real TracerProvider, subsequent spans will be recorded.
If no TracerProvider is configured (neither custom nor global), the
ProxyTracerProvider (default) will return a NoOp tracer, effectively
disabling tracing.

Returns:
Tuple[bool, Tracer]: (tracing_enabled, tracer)
- tracing_enabled is True if a real TracerProvider is configured
- tracer is always valid (may be NoOp if no provider configured)
"""
tracer_provider = trace.get_tracer_provider()
if provider is not None:
tracer_provider = provider
else:
tracer_provider = trace.get_tracer_provider()

tracer = tracer_provider.get_tracer(MISTRAL_SDK_OTEL_TRACER_NAME)

# Tracing is considered enabled if we have a real TracerProvider (not the default proxy)
Expand Down
124 changes: 124 additions & 0 deletions src/mistralai/extra/tests/test_otel_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1711,5 +1711,129 @@ async def _run():
)


class TestPerInstanceTracerProvider(unittest.TestCase):
"""Tests for per-instance tracer_provider support via set_tracer_provider."""

def test_custom_provider_captures_spans(self):
"""Spans go to the instance-specific exporter, not the global provider."""
# Create a standalone provider with its own exporter
custom_exporter = InMemorySpanExporter()
custom_provider = TracerProvider()
custom_provider.add_span_processor(SimpleSpanProcessor(custom_exporter))

# Clear the global exporter to ensure spans don't land there
_EXPORTER.clear()

# Set the custom provider on the hook directly (as set_tracer_provider does)
hook = TracingHook()
hook.tracer_provider = custom_provider

hook_ctx = _make_hook_context("chat_completion")

request_body = _dump(
ChatCompletionRequest(
model="mistral-small-latest",
messages=[UserMessage(content="hello")],
)
)
request = _make_httpx_request(request_body)

result = hook.before_request(BeforeRequestContext(hook_ctx), request)
assert isinstance(result, httpx.Request)

response_body = _dump(
ChatCompletionResponse(
id="custom-prov-1",
object="chat.completion",
model="mistral-small-latest",
created=1234567890,
choices=[
ChatCompletionChoice(
index=0,
message=AssistantMessage(content="hi"),
finish_reason="stop",
)
],
usage=UsageInfo(prompt_tokens=5, completion_tokens=3, total_tokens=8),
)
)
response = _make_httpx_response(response_body)
response.request = result
hook.after_success(AfterSuccessContext(hook_ctx), response)

# Spans should be in the custom exporter
custom_spans = custom_exporter.get_finished_spans()
self.assertEqual(len(custom_spans), 1)
self.assertEqual(custom_spans[0].attributes.get("gen_ai.request.model"), "mistral-small-latest")

# Global exporter should NOT have received the span
global_spans = [
s for s in _EXPORTER.get_finished_spans()
if s.attributes.get("gen_ai.response.id") == "custom-prov-1"
]
self.assertEqual(len(global_spans), 0)

def test_fallback_to_global_provider(self):
"""When tracer_provider is None (default), spans go to the global provider."""
_EXPORTER.clear()

hook = TracingHook()
# tracer_provider defaults to None — should use global provider
self.assertIsNone(hook.tracer_provider)

hook_ctx = _make_hook_context("chat_completion")

request_body = _dump(
ChatCompletionRequest(
model="mistral-small-latest",
messages=[UserMessage(content="fallback test")],
)
)
request = _make_httpx_request(request_body)
result = hook.before_request(BeforeRequestContext(hook_ctx), request)
assert isinstance(result, httpx.Request)

response_body = _dump(
ChatCompletionResponse(
id="fallback-1",
object="chat.completion",
model="mistral-small-latest",
created=1234567890,
choices=[
ChatCompletionChoice(
index=0,
message=AssistantMessage(content="response"),
finish_reason="stop",
)
],
usage=UsageInfo(prompt_tokens=5, completion_tokens=3, total_tokens=8),
)
)
response = _make_httpx_response(response_body)
response.request = result
hook.after_success(AfterSuccessContext(hook_ctx), response)

# Spans should be in the global exporter
global_spans = [
s for s in _EXPORTER.get_finished_spans()
if s.attributes.get("gen_ai.response.id") == "fallback-1"
]
self.assertEqual(len(global_spans), 1)

def test_set_tracer_provider_helper(self):
"""set_tracer_provider(client, provider) sets the provider on the TracingHook."""
from mistralai.extra.observability import set_tracer_provider

custom_provider = TracerProvider()
client = Mistral(api_key="test-key")
set_tracer_provider(client, custom_provider)

# Verify the TracingHook now has the custom provider
hooks = client.sdk_configuration.__dict__["_hooks"]
tracing_hooks = [h for h in hooks.before_request_hooks if isinstance(h, TracingHook)]
self.assertEqual(len(tracing_hooks), 1)
self.assertIs(tracing_hooks[0].tracer_provider, custom_provider)


if __name__ == "__main__":
unittest.main()
Loading