Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
247 changes: 210 additions & 37 deletions src/mistralai/client/_hooks/workflow_encoding_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import re
import uuid
import weakref
from typing import Any, Coroutine, Dict, Optional, TypeVar, Union
from typing import Any, AsyncIterator, Coroutine, Dict, Optional, TypeVar, Union

import httpx
from httpx._types import AsyncByteStream

from .types import (
AfterSuccessContext,
Expand All @@ -28,9 +29,7 @@


class _WorkflowEncodingConfig:
def __init__(
self, payload_encoder: PayloadEncoder, namespace: str
) -> None:
def __init__(self, payload_encoder: PayloadEncoder, namespace: str) -> None:
self.payload_encoder = payload_encoder
self.namespace = namespace

Expand Down Expand Up @@ -64,7 +63,9 @@ def configure_workflow_encoding(
)


def _get_encoding_config(sdk_config: SDKConfiguration) -> Optional[_WorkflowEncodingConfig]:
def _get_encoding_config(
sdk_config: SDKConfiguration,
) -> Optional[_WorkflowEncodingConfig]:
"""Get workflow encoding config for a client."""
config_id = getattr(sdk_config, _ENCODING_CONFIG_ID_ATTR, None)
if config_id is None:
Expand Down Expand Up @@ -100,9 +101,31 @@ def _get_encoding_config(sdk_config: SDKConfiguration) -> Optional[_WorkflowEnco
"update_workflow_execution_v1_workflows_executions__execution_id__updates_post",
}

# Operations that return event data that may need decryption
OPERATIONS_DECODE_EVENTS = {
"get_workflow_events_v1_workflows_events_list_get",
}

# Streaming operations that return SSE event data that may need decryption
OPERATIONS_DECODE_EVENTS_STREAM = {
"get_stream_events_v1_workflows_events_stream_get",
"stream_v1_workflows_executions__execution_id__stream_get",
}

SCHEDULE_CORRELATION_ID_PLACEHOLDER = "__scheduled_workflow__"


def _is_payload_type(value: Any) -> bool:
"""Check if a value is a JSONPayload or JSONPatchPayload by its structure.

Payload types have: {"type": "json" | "json_patch", "value": ...}
"""
if not isinstance(value, dict):
return False
payload_type = value.get("type")
return payload_type in ("json", "json_patch") and "value" in value


_T = TypeVar("_T")


Expand Down Expand Up @@ -143,6 +166,127 @@ def _extract_execution_id_from_body(body: Dict[str, Any]) -> Optional[str]:
return body.get("execution_id")


async def _decrypt_event_attributes(
    attributes: Dict[str, Any],
    payload_encoder: PayloadEncoder,
) -> Dict[str, Any]:
    """Decrypt every encrypted payload field in *attributes*, in place.

    Only values that look like payload wrappers (see ``_is_payload_type``)
    and carry a non-empty ``encoding_options`` list are decrypted; all other
    fields pass through untouched. Returns the same (mutated) dict.
    """
    for field_name in list(attributes):
        candidate = attributes[field_name]
        # Plain values and unencrypted payloads (no encoding_options) are skipped.
        if not _is_payload_type(candidate) or not candidate.get("encoding_options"):
            continue
        attributes[field_name] = await payload_encoder.decode_event_payload(candidate)
    return attributes


async def _decrypt_events_in_response(
    body: Dict[str, Any],
    payload_encoder: PayloadEncoder,
) -> Dict[str, Any]:
    """Decrypt the attributes of every event in a response *body*.

    Walks ``body["events"]`` and runs each event's dict-shaped ``attributes``
    through ``_decrypt_event_attributes``. Mutates *body* in place and
    returns it; a body without events is returned unchanged.
    """
    for event in body.get("events") or []:
        attrs = event.get("attributes")
        if isinstance(attrs, dict):
            event["attributes"] = await _decrypt_event_attributes(
                attrs, payload_encoder
            )
    return body


def _decrypt_sse_line(line: bytes, payload_encoder: PayloadEncoder) -> bytes:
    """Decrypt event payloads in a single SSE line.

    Non-``data:`` lines (comments, ``event:`` fields, blank keep-alives) pass
    through untouched, as does any ``data:`` line whose JSON body does not
    contain a dict ``data.attributes``. On any parsing or decryption failure
    the original line is returned unchanged so one malformed event never
    breaks the whole stream.
    """
    if not line.startswith(b"data:"):
        return line

    try:
        data_part = line[5:].strip()
        if not data_part:
            return line

        event_wrapper = json.loads(data_part)
        data = event_wrapper.get("data")
        if not isinstance(data, dict):
            return line

        attributes = data.get("attributes")
        if not isinstance(attributes, dict):
            return line

        # Decrypt in place - _decrypt_event_attributes modifies attributes dict.
        # NOTE(review): this is called from _DecryptingAsyncByteStream.__aiter__,
        # i.e. while an event loop may already be running — confirm _run_async
        # supports being invoked from inside a running loop.
        _run_async(_decrypt_event_attributes(attributes, payload_encoder))

        # A single space after the "data:" colon is stripped by SSE clients,
        # so this re-serialization is wire-compatible with the original framing.
        return b"data: " + json.dumps(event_wrapper).encode("utf-8")
    except Exception as e:
        # Was `except (json.JSONDecodeError, Exception)`: JSONDecodeError is a
        # subclass of Exception, so listing both was redundant (flake8 B014).
        logger.debug("SSE line decryption failed: %s", e)
        return line


class _DecryptingAsyncByteStream(AsyncByteStream):
    """Async byte stream wrapper that decrypts SSE event payloads.

    Buffers incoming chunks, re-splits them on newlines, and runs every
    complete line through ``_decrypt_sse_line`` before yielding it onward.
    """

    def __init__(self, original_stream: Any, payload_encoder: PayloadEncoder):
        self._original = original_stream
        self._payload_encoder = payload_encoder
        # Holds the trailing partial line between chunks.
        self._buffer = b""

    async def __aiter__(self) -> AsyncIterator[bytes]:
        async for raw_chunk in self._original:
            self._buffer += raw_chunk
            # split() always yields at least one element; the last one is the
            # (possibly empty) incomplete line, which stays buffered.
            *complete_lines, self._buffer = self._buffer.split(b"\n")
            for complete in complete_lines:
                yield _decrypt_sse_line(complete, self._payload_encoder) + b"\n"
        # Stream ended without a trailing newline: process what remains.
        if self._buffer:
            yield _decrypt_sse_line(self._buffer, self._payload_encoder)

    async def aclose(self) -> None:
        original_close = getattr(self._original, "aclose", None)
        if original_close is not None:
            await original_close()


def _wrap_sse_response_with_decryption(
    response: httpx.Response,
    payload_encoder: PayloadEncoder,
) -> httpx.Response:
    """Wrap an SSE response to decrypt event payloads as they stream.

    Builds a new ``httpx.Response`` whose stream decrypts payloads
    on-the-fly; the original body is never read eagerly.

    Args:
        response: The original streaming SSE response.
        payload_encoder: Encoder used to decrypt event payloads.

    Returns:
        A new response backed by a decrypting stream.
    """
    # Wrap the original stream so lines are decrypted as they arrive.
    decrypting_stream = _DecryptingAsyncByteStream(response.stream, payload_encoder)

    # Decryption changes the body length, so a stale Content-Length copied
    # from the origin would make clients truncate or over-read the stream.
    # (SSE responses are usually chunked, but guard against it anyway.)
    headers = httpx.Headers(response.headers)
    headers.pop("content-length", None)

    return httpx.Response(
        status_code=response.status_code,
        headers=headers,
        stream=decrypting_stream,
        request=response.request,
        extensions=response.extensions,
    )


class WorkflowEncodingHook(BeforeRequestHook, AfterSuccessHook):
Expand Down Expand Up @@ -203,12 +347,10 @@ def before_request(
execution_id=execution_id,
)

logger.debug(
"WorkflowEncodingHook: Encoding input for %s", hook_ctx.operation_id
)

encoded_input = _run_async(
encoding_config.payload_encoder.encode_network_input(input_data, context)
encoding_config.payload_encoder.encode_network_input(
input_data, context
)
)

# Update body based on operation type:
Expand Down Expand Up @@ -241,40 +383,71 @@ def after_success(
hook_ctx: AfterSuccessContext,
response: httpx.Response,
) -> Union[httpx.Response, Exception]:
"""Intercept responses to decode workflow result payloads."""
"""Intercept responses to decode workflow result payloads and event payloads."""
encoding_config = _get_encoding_config(hook_ctx.config)
if not encoding_config:
return response

if hook_ctx.operation_id not in OPERATIONS_DECODE_RESULT:
content_type = response.headers.get("content-type", "")

# Handle SSE stream decryption
if hook_ctx.operation_id in OPERATIONS_DECODE_EVENTS_STREAM:
if "text/event-stream" in content_type:
return _wrap_sse_response_with_decryption(
response, encoding_config.payload_encoder
)
return response

content_type = response.headers.get("content-type", "")
if "application/json" not in content_type:
return response

try:
body = json.loads(response.content)
result = body.get("result")
if result is None or not encoding_config.payload_encoder.check_is_payload_encoded(result):
return response

logger.debug(
"WorkflowEncodingHook: Decoding result for %s", hook_ctx.operation_id
)

decoded_result = _run_async(encoding_config.payload_encoder.decode_network_result(result))

body["result"] = decoded_result
new_content = json.dumps(body).encode("utf-8")
# Handle workflow result decoding
if hook_ctx.operation_id in OPERATIONS_DECODE_RESULT:
try:
body = json.loads(response.content)
result = body.get("result")
if (
result is not None
and encoding_config.payload_encoder.check_is_payload_encoded(
result
)
):
decoded_result = _run_async(
encoding_config.payload_encoder.decode_network_result(result)
)

body["result"] = decoded_result
new_content = json.dumps(body).encode("utf-8")

response = httpx.Response(
status_code=response.status_code,
headers=response.headers,
content=new_content,
request=response.request,
extensions=response.extensions,
)
except Exception as e:
logger.error("WorkflowEncodingHook: Failed to decode result: %s", e)
raise

# Handle event payload decoding
elif hook_ctx.operation_id in OPERATIONS_DECODE_EVENTS:
try:
body = json.loads(response.content)
body = _run_async(
_decrypt_events_in_response(body, encoding_config.payload_encoder)
)
new_content = json.dumps(body).encode("utf-8")

response = httpx.Response(
status_code=response.status_code,
headers=response.headers,
content=new_content,
request=response.request,
extensions=response.extensions,
)
except Exception as e:
logger.error("WorkflowEncodingHook: Failed to decode events: %s", e)
raise

return httpx.Response(
status_code=response.status_code,
headers=response.headers,
content=new_content,
request=response.request,
extensions=response.extensions,
)
except Exception as e:
logger.error("WorkflowEncodingHook: Failed to decode result: %s", e)
raise
return response
55 changes: 55 additions & 0 deletions src/mistralai/extra/workflows/encoding/payload_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,35 @@ async def encode_payload_content(

return data, encoding_options

async def encode_event_payload_content(
    self, data: Union[bytes, str], force_full_encryption: bool = False
) -> tuple[bytes, list[EncodedPayloadOptions]]:
    """Encrypt event payload content.

    Unlike encode_payload_content, this only handles encryption (no
    offloading).

    NOTE(review): the name is intentionally parallel to decode_network_result
    rather than decode_event_payload — the latter takes a whole event payload
    and extracts the encoding options and value itself; open to renaming
    suggestions.

    Args:
        data: The payload data to encrypt.
        force_full_encryption: Force full encryption regardless of configured
            mode. Use for payloads like json_patch that don't support partial
            encryption.

    Returns:
        The (possibly encrypted) payload bytes together with the list of
        encoding options describing what was applied.
    """
    raw = data.encode() if isinstance(data, str) else data

    # No encryption configured: bytes pass through untouched.
    if self.encryption_config is None:
        return raw, []

    wants_full = force_full_encryption or (
        self.encryption_config.mode == PayloadEncryptionMode.FULL
    )
    if wants_full:
        return self._encrypt(raw), [EncodedPayloadOptions.ENCRYPTED]

    # Partial encryption mode: encrypt only the selected fields, if any.
    raw, partially_encrypted = await self._partially_encrypt_fields(raw)
    if partially_encrypted:
        return raw, [EncodedPayloadOptions.PARTIALLY_ENCRYPTED]

    return raw, []

async def decode_payload_content(
self, data: bytes, encoding_options: List[EncodedPayloadOptions]
) -> bytes:
Expand Down Expand Up @@ -294,6 +323,32 @@ async def decode_payload_content(

return data

async def decode_event_payload(
    self, payload_data: Dict[str, Any]
) -> Dict[str, Any]:
    """Decrypt an event payload's value if it has encoding_options.

    Args:
        payload_data: Dict with 'type', 'value', and 'encoding_options'
            fields.

    Returns:
        Dict with decrypted 'value' and empty 'encoding_options'. A payload
        without encoding options is returned unchanged (same object).
    """
    option_strings = payload_data.get("encoding_options", [])
    if not option_strings:
        return payload_data

    # The encrypted value travels base64-encoded; decode, decrypt, then
    # parse the plaintext JSON back into a Python value.
    decrypted_bytes = await self.decode_payload_content(
        base64.b64decode(payload_data["value"]),
        [EncodedPayloadOptions(opt) for opt in option_strings],
    )
    return {
        "type": payload_data["type"],
        "value": json.loads(decrypted_bytes),
        "encoding_options": [],
    }

async def encode_network_input(
self, data: Optional[Dict[str, Any]], context: WorkflowContext
) -> NetworkEncodedInput:
Expand Down
Loading