diff --git a/src/mistralai/client/_hooks/workflow_encoding_hook.py b/src/mistralai/client/_hooks/workflow_encoding_hook.py index f383842d..d65c3ff4 100644 --- a/src/mistralai/client/_hooks/workflow_encoding_hook.py +++ b/src/mistralai/client/_hooks/workflow_encoding_hook.py @@ -5,9 +5,10 @@ import re import uuid import weakref -from typing import Any, Coroutine, Dict, Optional, TypeVar, Union +from typing import Any, AsyncIterator, Coroutine, Dict, Optional, TypeVar, Union import httpx +from httpx._types import AsyncByteStream from .types import ( AfterSuccessContext, @@ -28,9 +29,7 @@ class _WorkflowEncodingConfig: - def __init__( - self, payload_encoder: PayloadEncoder, namespace: str - ) -> None: + def __init__(self, payload_encoder: PayloadEncoder, namespace: str) -> None: self.payload_encoder = payload_encoder self.namespace = namespace @@ -64,7 +63,9 @@ def configure_workflow_encoding( ) -def _get_encoding_config(sdk_config: SDKConfiguration) -> Optional[_WorkflowEncodingConfig]: +def _get_encoding_config( + sdk_config: SDKConfiguration, +) -> Optional[_WorkflowEncodingConfig]: """Get workflow encoding config for a client.""" config_id = getattr(sdk_config, _ENCODING_CONFIG_ID_ATTR, None) if config_id is None: @@ -100,9 +101,31 @@ def _get_encoding_config(sdk_config: SDKConfiguration) -> Optional[_WorkflowEnco "update_workflow_execution_v1_workflows_executions__execution_id__updates_post", } +# Operations that return event data that may need decryption +OPERATIONS_DECODE_EVENTS = { + "get_workflow_events_v1_workflows_events_list_get", +} + +# Streaming operations that return SSE event data that may need decryption +OPERATIONS_DECODE_EVENTS_STREAM = { + "get_stream_events_v1_workflows_events_stream_get", + "stream_v1_workflows_executions__execution_id__stream_get", +} + SCHEDULE_CORRELATION_ID_PLACEHOLDER = "__scheduled_workflow__" +def _is_payload_type(value: Any) -> bool: + """Check if a value is a JSONPayload or JSONPatchPayload by its structure. + + Payload types have: {"type": "json" | "json_patch", "value": ...} + """ + if not isinstance(value, dict): + return False + payload_type = value.get("type") + return payload_type in ("json", "json_patch") and "value" in value + + _T = TypeVar("_T") @@ -143,6 +166,127 @@ def _extract_execution_id_from_body(body: Dict[str, Any]) -> Optional[str]: return body.get("execution_id") +async def _decrypt_event_attributes( + attributes: Dict[str, Any], + payload_encoder: PayloadEncoder, +) -> Dict[str, Any]: + """Decrypt payload fields in event attributes.""" + for field_name, field_value in attributes.items(): + if not _is_payload_type(field_value): + continue + + # Check if it has encoding_options (meaning it's encrypted) + if not field_value.get("encoding_options"): + continue + + # Decrypt the payload + decrypted = await payload_encoder.decode_event_payload(field_value) + attributes[field_name] = decrypted + + return attributes + + +async def _decrypt_events_in_response( + body: Dict[str, Any], + payload_encoder: PayloadEncoder, +) -> Dict[str, Any]: + """Decrypt payload fields in events within a response body.""" + events = body.get("events", []) + if not events: + return body + + for event in events: + attributes = event.get("attributes") + if isinstance(attributes, dict): + event["attributes"] = await _decrypt_event_attributes( + attributes, payload_encoder + ) + + return body + + +def _decrypt_sse_line(line: bytes, payload_encoder: PayloadEncoder) -> bytes: + """Decrypt event payloads in an SSE data line.""" + if not line.startswith(b"data:"): + return line + + try: + data_part = line[5:].strip() + if not data_part: + return line + + event_wrapper = json.loads(data_part) + data = event_wrapper.get("data") + if not isinstance(data, dict): + return line + + attributes = data.get("attributes") + if not isinstance(attributes, dict): + return line + + # Decrypt in place - _decrypt_event_attributes modifies attributes dict + _run_async(_decrypt_event_attributes(attributes, payload_encoder)) + + return b"data: " + json.dumps(event_wrapper).encode("utf-8") + except (json.JSONDecodeError, Exception) as e: + logger.debug("SSE line decryption failed: %s", e) + return line + + +class _DecryptingAsyncByteStream(AsyncByteStream): + """Async byte stream wrapper that decrypts SSE event payloads.""" + + def __init__(self, original_stream: Any, payload_encoder: PayloadEncoder): + self._original = original_stream + self._payload_encoder = payload_encoder + self._buffer = b"" + + async def __aiter__(self) -> AsyncIterator[bytes]: + async for chunk in self._original: + for processed in self._process_chunk(chunk): + yield processed + # Flush remaining buffer + if self._buffer: + yield _decrypt_sse_line(self._buffer, self._payload_encoder) + + def _process_chunk(self, chunk: bytes): + self._buffer += chunk + lines = self._buffer.split(b"\n") + # Keep last incomplete line in buffer + self._buffer = lines[-1] + for line in lines[:-1]: + yield _decrypt_sse_line(line, self._payload_encoder) + b"\n" + + async def aclose(self) -> None: + if hasattr(self._original, "aclose"): + await self._original.aclose() + + +def _wrap_sse_response_with_decryption( + response: httpx.Response, + payload_encoder: PayloadEncoder, +) -> httpx.Response: + """Wrap an SSE response to decrypt event payloads as they stream. + + Creates a new response with a custom stream that decrypts payloads on-the-fly. + """ + # Get the original stream from the response + original_stream = response.stream + + # Create wrapped stream + decrypting_stream = _DecryptingAsyncByteStream(original_stream, payload_encoder) + + # Create new response with wrapped stream + # Use internal _content to avoid reading stream + new_response = httpx.Response( + status_code=response.status_code, + headers=response.headers, + stream=decrypting_stream, + request=response.request, + extensions=response.extensions, + ) + + return new_response class WorkflowEncodingHook(BeforeRequestHook, AfterSuccessHook): @@ -203,12 +347,10 @@ def before_request( execution_id=execution_id, ) - logger.debug( - "WorkflowEncodingHook: Encoding input for %s", hook_ctx.operation_id - ) - encoded_input = _run_async( - encoding_config.payload_encoder.encode_network_input(input_data, context) + encoding_config.payload_encoder.encode_network_input( + input_data, context + ) ) # Update body based on operation type: @@ -241,40 +383,71 @@ def after_success( hook_ctx: AfterSuccessContext, response: httpx.Response, ) -> Union[httpx.Response, Exception]: - """Intercept responses to decode workflow result payloads.""" + """Intercept responses to decode workflow result payloads and event payloads.""" encoding_config = _get_encoding_config(hook_ctx.config) if not encoding_config: return response - if hook_ctx.operation_id not in OPERATIONS_DECODE_RESULT: + content_type = response.headers.get("content-type", "") + + # Handle SSE stream decryption + if hook_ctx.operation_id in OPERATIONS_DECODE_EVENTS_STREAM: + if "text/event-stream" in content_type: + return _wrap_sse_response_with_decryption( + response, encoding_config.payload_encoder + ) return response - content_type = response.headers.get("content-type", "") if "application/json" not in content_type: return response - try: - body = json.loads(response.content) - result = body.get("result") - if result is None or not encoding_config.payload_encoder.check_is_payload_encoded(result): - return response - - logger.debug( - "WorkflowEncodingHook: Decoding result for %s", hook_ctx.operation_id - ) - - decoded_result = _run_async(encoding_config.payload_encoder.decode_network_result(result)) - - body["result"] = decoded_result - new_content = json.dumps(body).encode("utf-8") + # Handle workflow result decoding + if hook_ctx.operation_id in OPERATIONS_DECODE_RESULT: + try: + body = json.loads(response.content) + result = body.get("result") + if ( + result is not None + and encoding_config.payload_encoder.check_is_payload_encoded( + result + ) + ): + decoded_result = _run_async( + encoding_config.payload_encoder.decode_network_result(result) + ) + + body["result"] = decoded_result + new_content = json.dumps(body).encode("utf-8") + + response = httpx.Response( + status_code=response.status_code, + headers=response.headers, + content=new_content, + request=response.request, + extensions=response.extensions, + ) + except Exception as e: + logger.error("WorkflowEncodingHook: Failed to decode result: %s", e) + raise + + # Handle event payload decoding + elif hook_ctx.operation_id in OPERATIONS_DECODE_EVENTS: + try: + body = json.loads(response.content) + body = _run_async( + _decrypt_events_in_response(body, encoding_config.payload_encoder) + ) + new_content = json.dumps(body).encode("utf-8") + + response = httpx.Response( + status_code=response.status_code, + headers=response.headers, + content=new_content, + request=response.request, + extensions=response.extensions, + ) + except Exception as e: + logger.error("WorkflowEncodingHook: Failed to decode events: %s", e) + raise - return httpx.Response( - status_code=response.status_code, - headers=response.headers, - content=new_content, - request=response.request, - extensions=response.extensions, - ) - except Exception as e: - logger.error("WorkflowEncodingHook: Failed to decode result: %s", e) - raise + return response diff --git a/src/mistralai/extra/workflows/encoding/payload_encoder.py b/src/mistralai/extra/workflows/encoding/payload_encoder.py index 802ae41b..611f33fa 100644 --- a/src/mistralai/extra/workflows/encoding/payload_encoder.py +++ b/src/mistralai/extra/workflows/encoding/payload_encoder.py @@ -263,6 +263,35 @@ async def encode_payload_content( return data, encoding_options + async def encode_event_payload_content( + self, data: Union[bytes, str], force_full_encryption: bool = False + ) -> tuple[bytes, list[EncodedPayloadOptions]]: + """Encrypt event payload content. + + Unlike encode_payload_content, this only handles encryption (no offloading). + + Args: + data: The payload data to encrypt. + force_full_encryption: Force full encryption regardless of configured mode. + Use for payloads like json_patch that don't support partial encryption. + """ + if isinstance(data, str): + data = data.encode() + + if self.encryption_config is None: + return data, [] + + if force_full_encryption or self.encryption_config.mode == PayloadEncryptionMode.FULL: + encrypted_data = self._encrypt(data) + return encrypted_data, [EncodedPayloadOptions.ENCRYPTED] + + # Partial encryption mode + data, partially_encrypted = await self._partially_encrypt_fields(data) + if partially_encrypted: + return data, [EncodedPayloadOptions.PARTIALLY_ENCRYPTED] + + return data, [] + async def decode_payload_content( self, data: bytes, encoding_options: List[EncodedPayloadOptions] ) -> bytes: @@ -294,6 +323,32 @@ async def decode_payload_content( return data + async def decode_event_payload( + self, payload_data: Dict[str, Any] + ) -> Dict[str, Any]: + """Decrypt an event payload's value if it has encoding_options. + + Args: + payload_data: Dict with 'type', 'value', and 'encoding_options' fields + + Returns: + Dict with decrypted 'value' and empty 'encoding_options' + """ + encoding_options_strs = payload_data.get("encoding_options", []) + if not encoding_options_strs: + return payload_data + + encoding_options = [EncodedPayloadOptions(opt) for opt in encoding_options_strs] + encrypted_bytes = base64.b64decode(payload_data["value"]) + decrypted_bytes = await self.decode_payload_content(encrypted_bytes, encoding_options) + decrypted_value = json.loads(decrypted_bytes) + + return { + "type": payload_data["type"], + "value": decrypted_value, + "encoding_options": [], + } + async def encode_network_input( self, data: Optional[Dict[str, Any]], context: WorkflowContext ) -> NetworkEncodedInput: