From adeff87df8f7ac0179082797daf7f3983adf2d7e Mon Sep 17 00:00:00 2001 From: Raquel Barbadillo Date: Wed, 6 May 2026 19:58:00 +0200 Subject: [PATCH] feat: support per-instance OpenTelemetry TracerProvider via set_tracer_provider Add a set_tracer_provider(client, provider) helper that attaches a custom TracerProvider to a specific Mistral client instance. Spans produced by that client are emitted through the custom provider; other instances continue using the global provider. Usage: from mistralai.extra.observability import set_tracer_provider set_tracer_provider(client, my_provider) Generated by Mistral Vibe. Co-Authored-By: Mistral Vibe --- src/mistralai/client/_hooks/tracing.py | 6 +- src/mistralai/extra/observability/__init__.py | 42 +++++- src/mistralai/extra/observability/otel.py | 25 ++-- .../extra/tests/test_otel_tracing.py | 124 ++++++++++++++++++ 4 files changed, 185 insertions(+), 12 deletions(-) diff --git a/src/mistralai/client/_hooks/tracing.py b/src/mistralai/client/_hooks/tracing.py index 632320ce..14c8cffc 100644 --- a/src/mistralai/client/_hooks/tracing.py +++ b/src/mistralai/client/_hooks/tracing.py @@ -2,6 +2,7 @@ from typing import Optional, Tuple, Union import httpx +from opentelemetry import trace from opentelemetry.trace import Span from mistralai.extra.observability.otel import ( @@ -27,6 +28,7 @@ class TracingHook(BeforeRequestHook, AfterSuccessHook, AfterErrorHook): def __init__(self) -> None: + self.tracer_provider: Optional[trace.TracerProvider] = None self.tracing_enabled, self.tracer = get_or_create_otel_tracer() def before_request( @@ -34,7 +36,9 @@ def before_request( ) -> Union[httpx.Request, Exception]: # Refresh tracer/provider per request so tracing can be enabled if the # application configures OpenTelemetry after the client is instantiated. - self.tracing_enabled, self.tracer = get_or_create_otel_tracer() + self.tracing_enabled, self.tracer = get_or_create_otel_tracer( + provider=self.tracer_provider, + ) request, span = get_traced_request_and_span( tracing_enabled=self.tracing_enabled, tracer=self.tracer, diff --git a/src/mistralai/extra/observability/__init__.py b/src/mistralai/extra/observability/__init__.py index 4ff5873c..d3ae6cd9 100644 --- a/src/mistralai/extra/observability/__init__.py +++ b/src/mistralai/extra/observability/__init__.py @@ -1,9 +1,13 @@ from contextlib import contextmanager +from typing import TYPE_CHECKING from opentelemetry import trace as otel_trace from .otel import MISTRAL_SDK_OTEL_TRACER_NAME +if TYPE_CHECKING: + from mistralai.client.sdk import Mistral + @contextmanager def trace(name: str, **kwargs): @@ -12,4 +16,40 @@ def trace(name: str, **kwargs): yield span -__all__ = ["trace"] +def set_tracer_provider( + client: "Mistral", + provider: otel_trace.TracerProvider, +) -> None: + """Attach a per-instance OpenTelemetry TracerProvider to a Mistral client. + + When set, all SDK spans produced by *client* will be emitted through + *provider* instead of the global TracerProvider. + + Usage:: + + from opentelemetry.sdk.trace import TracerProvider + from mistralai.client import Mistral + from mistralai.extra.observability import set_tracer_provider + + client = Mistral(api_key="...") + set_tracer_provider(client, TracerProvider()) + """ + from mistralai.client._hooks.tracing import TracingHook + + hooks = getattr(client.sdk_configuration, "_hooks", None) + if hooks is None: + raise ValueError( + "Cannot set tracer_provider: SDK hooks not initialised on this client." + ) + + for hook in hooks.before_request_hooks: + if isinstance(hook, TracingHook): + hook.tracer_provider = provider + return + + raise ValueError( + "Cannot set tracer_provider: TracingHook not found in the client's hooks." + ) + + +__all__ = ["trace", "set_tracer_provider"] diff --git a/src/mistralai/extra/observability/otel.py b/src/mistralai/extra/observability/otel.py index 7c75271e..f71fb301 100644 --- a/src/mistralai/extra/observability/otel.py +++ b/src/mistralai/extra/observability/otel.py @@ -417,25 +417,30 @@ def _enrich_span_from_response( _enrich_ocr(span, response_data) -def get_or_create_otel_tracer() -> tuple[bool, Tracer]: +def get_or_create_otel_tracer( + provider: trace.TracerProvider | None = None, +) -> tuple[bool, Tracer]: """ - Get a tracer from the current TracerProvider. + Get a tracer from the given or global TracerProvider. - The SDK does not set up its own TracerProvider - it relies on the application - to configure OpenTelemetry. This follows OTEL best practices where: - - Libraries/SDKs get tracers from the global provider - - Applications configure the TracerProvider + When *provider* is supplied (per-instance tracer provider), the tracer is + obtained from it directly. Otherwise the global provider is used, following + the standard OTEL library convention. - If no TracerProvider is configured, the ProxyTracerProvider (default) will - return a NoOp tracer, effectively disabling tracing. Once the application - sets up a real TracerProvider, subsequent spans will be recorded. + If no TracerProvider is configured (neither custom nor global), the + ProxyTracerProvider (default) will return a NoOp tracer, effectively + disabling tracing. Returns: Tuple[bool, Tracer]: (tracing_enabled, tracer) - tracing_enabled is True if a real TracerProvider is configured - tracer is always valid (may be NoOp if no provider configured) """ - tracer_provider = trace.get_tracer_provider() + if provider is not None: + tracer_provider = provider + else: + tracer_provider = trace.get_tracer_provider() + tracer = tracer_provider.get_tracer(MISTRAL_SDK_OTEL_TRACER_NAME) # Tracing is considered enabled if we have a real TracerProvider (not the default proxy) diff --git a/src/mistralai/extra/tests/test_otel_tracing.py b/src/mistralai/extra/tests/test_otel_tracing.py index 818d38e3..82b78ec1 100644 --- a/src/mistralai/extra/tests/test_otel_tracing.py +++ b/src/mistralai/extra/tests/test_otel_tracing.py @@ -1711,5 +1711,129 @@ async def _run(): ) +class TestPerInstanceTracerProvider(unittest.TestCase): + """Tests for per-instance tracer_provider support via set_tracer_provider.""" + + def test_custom_provider_captures_spans(self): + """Spans go to the instance-specific exporter, not the global provider.""" + # Create a standalone provider with its own exporter + custom_exporter = InMemorySpanExporter() + custom_provider = TracerProvider() + custom_provider.add_span_processor(SimpleSpanProcessor(custom_exporter)) + + # Clear the global exporter to ensure spans don't land there + _EXPORTER.clear() + + # Set the custom provider on the hook directly (as set_tracer_provider does) + hook = TracingHook() + hook.tracer_provider = custom_provider + + hook_ctx = _make_hook_context("chat_completion") + + request_body = _dump( + ChatCompletionRequest( + model="mistral-small-latest", + messages=[UserMessage(content="hello")], + ) + ) + request = _make_httpx_request(request_body) + + result = hook.before_request(BeforeRequestContext(hook_ctx), request) + assert isinstance(result, httpx.Request) + + response_body = _dump( + ChatCompletionResponse( + id="custom-prov-1", + object="chat.completion", + model="mistral-small-latest", + created=1234567890, + choices=[ + ChatCompletionChoice( + index=0, + message=AssistantMessage(content="hi"), + finish_reason="stop", + ) + ], + usage=UsageInfo(prompt_tokens=5, completion_tokens=3, total_tokens=8), + ) + ) + response = _make_httpx_response(response_body) + response.request = result + hook.after_success(AfterSuccessContext(hook_ctx), response) + + # Spans should be in the custom exporter + custom_spans = custom_exporter.get_finished_spans() + self.assertEqual(len(custom_spans), 1) + self.assertEqual(custom_spans[0].attributes.get("gen_ai.request.model"), "mistral-small-latest") + + # Global exporter should NOT have received the span + global_spans = [ + s for s in _EXPORTER.get_finished_spans() + if s.attributes.get("gen_ai.response.id") == "custom-prov-1" + ] + self.assertEqual(len(global_spans), 0) + + def test_fallback_to_global_provider(self): + """When tracer_provider is None (default), spans go to the global provider.""" + _EXPORTER.clear() + + hook = TracingHook() + # tracer_provider defaults to None — should use global provider + self.assertIsNone(hook.tracer_provider) + + hook_ctx = _make_hook_context("chat_completion") + + request_body = _dump( + ChatCompletionRequest( + model="mistral-small-latest", + messages=[UserMessage(content="fallback test")], + ) + ) + request = _make_httpx_request(request_body) + result = hook.before_request(BeforeRequestContext(hook_ctx), request) + assert isinstance(result, httpx.Request) + + response_body = _dump( + ChatCompletionResponse( + id="fallback-1", + object="chat.completion", + model="mistral-small-latest", + created=1234567890, + choices=[ + ChatCompletionChoice( + index=0, + message=AssistantMessage(content="response"), + finish_reason="stop", + ) + ], + usage=UsageInfo(prompt_tokens=5, completion_tokens=3, total_tokens=8), + ) + ) + response = _make_httpx_response(response_body) + response.request = result + hook.after_success(AfterSuccessContext(hook_ctx), response) + + # Spans should be in the global exporter + global_spans = [ + s for s in _EXPORTER.get_finished_spans() + if s.attributes.get("gen_ai.response.id") == "fallback-1" + ] + self.assertEqual(len(global_spans), 1) + + def test_set_tracer_provider_helper(self): + """set_tracer_provider(client, provider) sets the provider on the TracingHook.""" + from mistralai.extra.observability import set_tracer_provider + + custom_provider = TracerProvider() + client = Mistral(api_key="test-key") + set_tracer_provider(client, custom_provider) + + # Verify the TracingHook now has the custom provider + hooks = client.sdk_configuration.__dict__["_hooks"] + tracing_hooks = [h for h in hooks.before_request_hooks if isinstance(h, TracingHook)] + self.assertEqual(len(tracing_hooks), 1) + self.assertIs(tracing_hooks[0].tracer_provider, custom_provider) + + if __name__ == "__main__": unittest.main()