From 9c57a41b2cd726dd87f0a98b990b286474d9117c Mon Sep 17 00:00:00 2001 From: Kaiohz Date: Wed, 20 May 2026 17:19:21 +0200 Subject: [PATCH] Add post-validation of LLM structured responses against JSON Schema contract Strip extra fields the LLM invents outside the response_format schema using Pydantic models with extra='ignore'. Adds schema_utils.py for JSON Schema to Pydantic model conversion, validates in DeepAgentRunner._build_response(), and logs warnings for stripped fields. --- src/infrastructure/deepagent/adapter.py | 33 ++- src/infrastructure/deepagent/factory.py | 14 +- src/infrastructure/deepagent/schema_utils.py | 101 +++++++++ .../persistent_registry/adapter.py | 4 +- tests/unit/test_deep_agent_runner.py | 120 ++++++++++ tests/unit/test_persistent_registry.py | 8 +- tests/unit/test_routes.py | 2 +- tests/unit/test_schema_utils.py | 213 ++++++++++++++++++ 8 files changed, 479 insertions(+), 16 deletions(-) create mode 100644 src/infrastructure/deepagent/schema_utils.py create mode 100644 tests/unit/test_schema_utils.py diff --git a/src/infrastructure/deepagent/adapter.py b/src/infrastructure/deepagent/adapter.py index 5d6203c..e5a0847 100644 --- a/src/infrastructure/deepagent/adapter.py +++ b/src/infrastructure/deepagent/adapter.py @@ -5,6 +5,7 @@ from collections.abc import AsyncIterator from langgraph.types import Command +from pydantic import BaseModel from src.domain.entities.message import Message, MessageRole, MessageStatus from src.domain.entities.stream_event import StreamEvent, StreamEventType @@ -16,9 +17,10 @@ class DeepAgentRunner(AgentRunner): - def __init__(self, graph, tracing_provider: TracingProvider | None = None): + def __init__(self, graph, tracing_provider: TracingProvider | None = None, response_format_model: type[BaseModel] | None = None): self._graph = graph self._tracing_provider = tracing_provider + self._response_format_model = response_format_model @staticmethod def _try_parse_json(content: str) -> dict | None: @@ -40,6 +42,31 @@ def _try_parse_json(content: str) -> dict | None: pass return None + def _validate_structured_response(self, data: dict) -> dict: + """Validate structured_response against the response_format model. + + Strips any extra fields not defined in the schema and logs warnings. + """ + try: + validated = self._response_format_model.model_validate(data) + cleaned = validated.model_dump() + self._log_extra_fields(data, cleaned) + return cleaned + except Exception: + logger.warning("Failed to validate structured_response against schema, returning raw data") + return data + + @staticmethod + def _log_extra_fields(original: dict, cleaned: dict) -> None: + """Log any top-level or nested fields that were stripped.""" + for key in original: + if key not in cleaned: + logger.warning("Stripped extra field from structured_response: '%s'", key) + elif isinstance(original[key], dict) and isinstance(cleaned[key], dict): + for sub_key in original[key]: + if sub_key not in cleaned[key]: + logger.warning("Stripped extra nested field: '%s.%s'", key, sub_key) + @staticmethod def _is_nonblank_str(val: object) -> bool: return isinstance(val, str) and val.strip() != "" @@ -114,6 +141,10 @@ def _build_response(self, result: dict, config: dict, thinking: str | None) -> M if structured_response is None: structured_response = self._try_parse_json(last_message.content) + # 4. Validate against response_format schema (strip extra fields) + if structured_response is not None and self._response_format_model is not None: + structured_response = self._validate_structured_response(structured_response) + return Message( role=MessageRole.AI, content=last_message.content, diff --git a/src/infrastructure/deepagent/factory.py b/src/infrastructure/deepagent/factory.py index 7cc3a0e..3aa4813 100644 --- a/src/infrastructure/deepagent/factory.py +++ b/src/infrastructure/deepagent/factory.py @@ -15,6 +15,7 @@ from src.domain.entities.agent_config import AgentConfig, BackendType from src.domain.ports.mcp_tool_loader import McpToolLoader from src.domain.ports.prompt_manager import PromptManager +from src.infrastructure.deepagent.schema_utils import make_validation_model logger = logging.getLogger(__name__) @@ -193,7 +194,7 @@ async def create_agent_from_config( mcp_tool_loader: Optional MCP tool loader for loading remote tools. Returns: - The compiled agent ready for execution. + Tuple of (compiled agent graph, response_format_model or None). """ logger.info("Creating agent '%s' (model=%s)", config.name, config.model) checkpointer = MemorySaver() @@ -238,17 +239,14 @@ async def create_agent_from_config( kwargs["skills"] = config.skills if config.response_format: - # Use tool-based structured output instead of ProviderStrategy to bypass - # Bedrock schema limitations (max 16 anyOf, max 24 optionals, max grammar size). - # The structured_response tool is injected into the agent's tool list so the LLM - # sees it, but we do NOT pass response_format to avoid create_agent forcing - # tool_choice="any", which suppresses intermediate streaming messages. - # The system prompt is augmented with an instruction to use the tool. + response_format_model = make_validation_model(config.response_format) response_tool = _create_response_tool(config.response_format) all_tools = (all_tools or []) + [response_tool] kwargs["tools"] = all_tools current_prompt = kwargs.get("system_prompt", "") kwargs["system_prompt"] = (current_prompt or "") + STRUCTURED_OUTPUT_INSTRUCTION + else: + response_format_model = None subagents = await _resolve_subagents(config, mcp_tool_loader, prompt_manager) if subagents: @@ -260,7 +258,7 @@ async def create_agent_from_config( logger.error(f"Error creating agent '{config.name}': {e}") raise logger.info("Agent '%s' created successfully", config.name) - return graph + return graph, response_format_model # helper to get system_prompt from Phoenix diff --git a/src/infrastructure/deepagent/schema_utils.py b/src/infrastructure/deepagent/schema_utils.py new file mode 100644 index 0000000..35fba14 --- /dev/null +++ b/src/infrastructure/deepagent/schema_utils.py @@ -0,0 +1,101 @@ +import hashlib +import json +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field, create_model + + +def schema_to_pydantic_model(schema: dict[str, Any], model_name: str = "DynamicModel") -> type[BaseModel]: + """Convert a JSON Schema dict to a Pydantic BaseModel with extra='ignore'. + + Recursively builds nested models for object properties, stripping any + fields not present in the schema when model_validate is called. + """ + return _build_model(schema, model_name) + + +def _build_model(schema: dict[str, Any], name: str) -> type[BaseModel]: + schema_type = schema.get("type", "object") + + if schema_type == "object": + return _build_object_model(schema, name) + if schema_type == "array": + return _build_array_model(schema, name) + + primitives = {"string": str, "number": float, "integer": int, "boolean": bool} + return primitives.get(schema_type, str) + + +def _build_object_model(schema: dict[str, Any], name: str) -> type[BaseModel]: + properties = schema.get("properties", {}) + required = set(schema.get("required", [])) + + if not properties: + return create_model( + name, + __config__=ConfigDict(extra="ignore"), + **{_sanitize(k): (dict | None, Field(default=None)) for k in properties}, + ) + + field_defs: dict[str, Any] = {} + for prop_name, prop_schema in properties.items(): + python_name = _sanitize(prop_name) + prop_type = prop_schema.get("type", "string") + + if prop_type == "object": + nested_name = f"{name}_{_sanitize(prop_name).capitalize()}" + nested_model = _build_object_model(prop_schema, nested_name) + if prop_name in required: + field_defs[python_name] = (nested_model, Field(description=prop_schema.get("description", ""))) + else: + field_defs[python_name] = (nested_model | None, Field(default=None, description=prop_schema.get("description", ""))) + elif prop_type == "array": + items_schema = prop_schema.get("items", {}) + items_type = items_schema.get("type", "string") + if items_type == "object": + nested_name = f"{name}_{_sanitize(prop_name).capitalize()}Item" + nested_model = _build_object_model(items_schema, nested_name) + list_type = list[nested_model] + else: + primitives = {"string": str, "number": float, "integer": int, "boolean": bool} + list_type = list[primitives.get(items_type, str)] + if prop_name in required: + field_defs[python_name] = (list_type, Field(default_factory=list, description=prop_schema.get("description", ""))) + else: + field_defs[python_name] = (list_type | None, Field(default=None, description=prop_schema.get("description", ""))) + else: + primitives = {"string": str, "number": float, "integer": int, "boolean": bool} + python_type = primitives.get(prop_type, str) + if prop_name in required: + field_defs[python_name] = (python_type, Field(description=prop_schema.get("description", ""))) + else: + field_defs[python_name] = (python_type | None, Field(default=None, description=prop_schema.get("description", ""))) + + return create_model(name, __config__=ConfigDict(extra="ignore"), **field_defs) + + +def _build_array_model(schema: dict[str, Any], name: str) -> type: + items_schema = schema.get("items", {}) + items_type = items_schema.get("type", "string") + if items_type == "object": + nested_model = _build_object_model(items_schema, f"{name}Item") + return list[nested_model] + primitives = {"string": str, "number": float, "integer": int, "boolean": bool} + return list[primitives.get(items_type, str)] + + +def _sanitize(name: str) -> str: + sanitized = name.replace("-", "_").replace(" ", "_") + if sanitized[0].isdigit(): + sanitized = f"field_{sanitized}" + return sanitized + + +def make_validation_model(response_format: dict[str, Any]) -> type[BaseModel]: + """Create a Pydantic validation model from an agent's response_format schema. + + The resulting model uses extra='ignore' so any fields not in the schema + are silently stripped on model_validate. + """ + schema_hash = hashlib.sha256(json.dumps(response_format, sort_keys=True).encode()).hexdigest()[:8] + return schema_to_pydantic_model(response_format, f"ResponseFormat_{schema_hash}") diff --git a/src/infrastructure/persistent_registry/adapter.py b/src/infrastructure/persistent_registry/adapter.py index e6d1982..b24c585 100644 --- a/src/infrastructure/persistent_registry/adapter.py +++ b/src/infrastructure/persistent_registry/adapter.py @@ -59,8 +59,8 @@ async def get_runner(self, agent_name: str) -> AgentRunner: logger.info("Building agent '%s' from persistent store", agent_name) yaml_content = await self._config_store.get(agent_name) config = self._config_loader.load_from_string(yaml_content) - graph = await create_agent_from_config(config, self._mcp_tool_loader, self._prompt_manager) - runner = DeepAgentRunner(graph, tracing_provider=self._tracing_provider) + graph, response_format_model = await create_agent_from_config(config, self._mcp_tool_loader, self._prompt_manager) + runner = DeepAgentRunner(graph, tracing_provider=self._tracing_provider, response_format_model=response_format_model) self._runners[agent_name] = runner logger.info("Agent '%s' ready and cached", agent_name) return runner diff --git a/tests/unit/test_deep_agent_runner.py b/tests/unit/test_deep_agent_runner.py index d800ac7..e47f08d 100644 --- a/tests/unit/test_deep_agent_runner.py +++ b/tests/unit/test_deep_agent_runner.py @@ -293,3 +293,123 @@ async def test_build_response_no_structured_response(self): result = await runner.invoke("thread-1", "hi") assert result.structured_response is None + + # --- Post-validation tests --- + + async def test_validate_structured_response_strips_extra_top_level_fields(self): + """Extra top-level fields invented by the LLM are stripped.""" + from src.infrastructure.deepagent.schema_utils import make_validation_model + + schema = { + "type": "object", + "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, + "required": ["name"], + } + model = make_validation_model(schema) + mock_msg = MagicMock() + mock_msg.content = "Result" + mock_msg.tool_calls = None + graph = _make_graph([mock_msg]) + graph.ainvoke.return_value = { + "messages": [mock_msg], + "structured_response": {"name": "Alice", "age": 30, "terraceArea": 50, "parkingSpaces": 2}, + } + + runner = DeepAgentRunner(graph, response_format_model=model) + result = await runner.invoke("thread-1", "analyze") + + assert result.structured_response == {"name": "Alice", "age": 30} + assert "terraceArea" not in result.structured_response + assert "parkingSpaces" not in result.structured_response + + async def test_validate_structured_response_strips_nested_extra_fields(self): + """Extra nested fields invented by the LLM are stripped.""" + from src.infrastructure.deepagent.schema_utils import make_validation_model + + schema = { + "type": "object", + "properties": { + "building": { + "type": "object", + "properties": {"floors": {"type": "integer"}}, + "required": ["floors"], + } + }, + "required": ["building"], + } + model = make_validation_model(schema) + mock_msg = MagicMock() + mock_msg.content = "Result" + mock_msg.tool_calls = None + graph = _make_graph([mock_msg]) + graph.ainvoke.return_value = { + "messages": [mock_msg], + "structured_response": {"building": {"floors": 3, "rooftop": True}}, + } + + runner = DeepAgentRunner(graph, response_format_model=model) + result = await runner.invoke("thread-1", "analyze") + + assert result.structured_response == {"building": {"floors": 3}} + assert "rooftop" not in result.structured_response["building"] + + async def test_validate_structured_response_no_model_returns_raw(self): + """When no response_format_model is set, data passes through unmodified.""" + mock_msg = MagicMock() + mock_msg.content = "Result" + mock_msg.tool_calls = None + graph = _make_graph([mock_msg]) + graph.ainvoke.return_value = { + "messages": [mock_msg], + "structured_response": {"name": "test", "extra": True}, + } + + runner = DeepAgentRunner(graph, response_format_model=None) + result = await runner.invoke("thread-1", "analyze") + + assert result.structured_response == {"name": "test", "extra": True} + + def test_log_extra_fields_logs_top_level(self, caplog): + """_log_extra_fields logs warnings for stripped top-level keys.""" + import logging + + with caplog.at_level(logging.WARNING): + DeepAgentRunner._log_extra_fields( + {"name": "a", "invented": 1}, + {"name": "a"}, + ) + assert "invented" in caplog.text + + def test_log_extra_fields_logs_nested(self, caplog): + """_log_extra_fields logs warnings for stripped nested keys.""" + import logging + + with caplog.at_level(logging.WARNING): + DeepAgentRunner._log_extra_fields( + {"building": {"floors": 3, "bogus": 1}}, + {"building": {"floors": 3}}, + ) + assert "building.bogus" in caplog.text + + async def test_validate_structured_response_from_tool_call(self): + """structured_response extracted from tool_calls is also validated.""" + from src.infrastructure.deepagent.schema_utils import make_validation_model + + schema = { + "type": "object", + "properties": {"summary": {"type": "string"}}, + "required": ["summary"], + } + model = make_validation_model(schema) + ai_msg = MagicMock() + ai_msg.content = "Done" + ai_msg.tool_calls = [ + {"name": "structured_response", "args": {"summary": "ok", "hallucinated": 99}, "id": "tc-1"} + ] + graph = _make_graph([ai_msg]) + + runner = DeepAgentRunner(graph, response_format_model=model) + result = await runner.invoke("thread-1", "summarize") + + assert result.structured_response == {"summary": "ok"} + assert "hallucinated" not in result.structured_response diff --git a/tests/unit/test_persistent_registry.py b/tests/unit/test_persistent_registry.py index 006484f..71ed175 100644 --- a/tests/unit/test_persistent_registry.py +++ b/tests/unit/test_persistent_registry.py @@ -67,7 +67,7 @@ async def test_get_runner_loads_from_store(self, mock_runner_cls, mock_create, r """get_runner should fetch YAML from MinIO, parse it, create agent, cache runner.""" mock_store.get.return_value = VALID_YAML mock_graph = MagicMock() - mock_create.return_value = mock_graph + mock_create.return_value = (mock_graph, None) mock_runner_instance = MagicMock() mock_runner_cls.return_value = mock_runner_instance @@ -85,7 +85,7 @@ async def test_get_runner_loads_from_store(self, mock_runner_cls, mock_create, r async def test_get_runner_cache_hit(self, mock_runner_cls, mock_create, registry, mock_store): """Second call should return cached runner without fetching from store again.""" mock_store.get.return_value = VALID_YAML - mock_create.return_value = MagicMock() + mock_create.return_value = (MagicMock(), None) mock_runner_instance = MagicMock() mock_runner_cls.return_value = mock_runner_instance @@ -124,7 +124,7 @@ async def test_list_agents_queries_repository(self, registry, mock_repository): async def test_invalidate_clears_cache(self, mock_runner_cls, mock_create, registry, mock_store): """After invalidate, next get_runner should re-fetch from store.""" mock_store.get.return_value = VALID_YAML - mock_create.return_value = MagicMock() + mock_create.return_value = (MagicMock(), None) runner_a = MagicMock() runner_b = MagicMock() mock_runner_cls.side_effect = [runner_a, runner_b] @@ -148,7 +148,7 @@ async def test_invalidate_clears_cache(self, mock_runner_cls, mock_create, regis async def test_close_clears_all_runners(self, mock_runner_cls, mock_create, registry, mock_store): """close should empty the runners cache.""" mock_store.get.return_value = VALID_YAML - mock_create.return_value = MagicMock() + mock_create.return_value = (MagicMock(), None) mock_runner_cls.return_value = MagicMock() await registry.get_runner("test-agent") diff --git a/tests/unit/test_routes.py b/tests/unit/test_routes.py index 53637b6..958c0df 100644 --- a/tests/unit/test_routes.py +++ b/tests/unit/test_routes.py @@ -146,7 +146,7 @@ def _wire_dependencies( patch( "src.infrastructure.persistent_registry.adapter.create_agent_from_config", new_callable=AsyncMock, - return_value=MagicMock(), + return_value=(MagicMock(), None), ), patch( "src.infrastructure.persistent_registry.adapter.DeepAgentRunner", diff --git a/tests/unit/test_schema_utils.py b/tests/unit/test_schema_utils.py new file mode 100644 index 0000000..21d7837 --- /dev/null +++ b/tests/unit/test_schema_utils.py @@ -0,0 +1,213 @@ +"""Tests for schema_utils: schema_to_pydantic_model, make_validation_model.""" + +from pydantic import BaseModel + +from src.infrastructure.deepagent.schema_utils import ( + _build_array_model, + _sanitize, + make_validation_model, + schema_to_pydantic_model, +) + + +class TestSanitize: + def test_hyphens(self): + assert _sanitize("my-field") == "my_field" + + def test_spaces(self): + assert _sanitize("my field") == "my_field" + + def test_leading_digit(self): + assert _sanitize("1field") == "field_1field" + + def test_no_change(self): + assert _sanitize("normalName") == "normalName" + + +class TestSchemaToPydanticModel: + def test_flat_object(self): + schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name"], + } + model = schema_to_pydantic_model(schema, "FlatModel") + assert issubclass(model, BaseModel) + instance = model.model_validate({"name": "Alice", "age": 30}) + assert instance.name == "Alice" + assert instance.age == 30 + + def test_flat_object_strips_extra_fields(self): + schema = { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + } + model = schema_to_pydantic_model(schema, "StripModel") + instance = model.model_validate({"name": "Bob", "invented": "extra"}) + assert instance.name == "Bob" + dumped = instance.model_dump() + assert "invented" not in dumped + + def test_optional_fields_default_none(self): + schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "score": {"type": "number"}, + }, + "required": ["name"], + } + model = schema_to_pydantic_model(schema, "OptModel") + instance = model.model_validate({"name": "test"}) + assert instance.name == "test" + assert instance.score is None + + def test_nested_object(self): + schema = { + "type": "object", + "properties": { + "address": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "zip": {"type": "integer"}, + }, + "required": ["city"], + } + }, + "required": ["address"], + } + model = schema_to_pydantic_model(schema, "NestedModel") + instance = model.model_validate({"address": {"city": "Paris", "zip": 75001}}) + assert instance.address.city == "Paris" + assert instance.address.zip == 75001 + + def test_nested_object_strips_extra_fields(self): + schema = { + "type": "object", + "properties": { + "address": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + } + }, + "required": ["address"], + } + model = schema_to_pydantic_model(schema, "NestedStripModel") + instance = model.model_validate({"address": {"city": "Lyon", "bogus": 99}}) + dumped = instance.model_dump() + assert "bogus" not in dumped["address"] + + def test_array_of_primitives(self): + schema = { + "type": "object", + "properties": { + "tags": { + "type": "array", + "items": {"type": "string"}, + } + }, + "required": ["tags"], + } + model = schema_to_pydantic_model(schema, "ArrayModel") + instance = model.model_validate({"tags": ["a", "b"]}) + assert instance.tags == ["a", "b"] + + def test_array_of_objects(self): + schema = { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + }, + } + }, + "required": ["items"], + } + model = schema_to_pydantic_model(schema, "ObjArrayModel") + instance = model.model_validate({"items": [{"name": "x"}, {"name": "y"}]}) + assert len(instance.items) == 2 + + def test_primitives(self): + assert schema_to_pydantic_model({"type": "string"}, "M") is str + assert schema_to_pydantic_model({"type": "integer"}, "M") is int + assert schema_to_pydantic_model({"type": "number"}, "M") is float + assert schema_to_pydantic_model({"type": "boolean"}, "M") is bool + + def test_empty_properties(self): + schema = {"type": "object", "properties": {}, "required": []} + model = schema_to_pydantic_model(schema, "EmptyModel") + instance = model.model_validate({"anything": "here"}) + dumped = instance.model_dump() + assert "anything" not in dumped + + def test_sanitize_field_names(self): + schema = { + "type": "object", + "properties": { + "my-field": {"type": "string"}, + "2ndField": {"type": "integer"}, + }, + "required": ["my-field"], + } + model = schema_to_pydantic_model(schema, "SanitizeModel") + instance = model.model_validate({"my_field": "val", "field_2ndField": 5}) + assert instance.my_field == "val" + assert instance.field_2ndField == 5 + + +class TestBuildArrayModel: + def test_array_of_objects(self): + schema = { + "type": "array", + "items": { + "type": "object", + "properties": {"id": {"type": "integer"}}, + "required": ["id"], + }, + } + result = _build_array_model(schema, "TestArr") + assert hasattr(result, "__origin__") + + def test_array_of_strings(self): + schema = {"type": "array", "items": {"type": "string"}} + result = _build_array_model(schema, "TestArr") + assert hasattr(result, "__origin__") + + +class TestMakeValidationModel: + def test_creates_model_with_hash_name(self): + schema = { + "type": "object", + "properties": {"value": {"type": "string"}}, + "required": ["value"], + } + model = make_validation_model(schema) + assert issubclass(model, BaseModel) + assert model.__name__.startswith("ResponseFormat_") + + def test_same_schema_produces_same_hash(self): + schema = { + "type": "object", + "properties": {"value": {"type": "string"}}, + "required": ["value"], + } + model1 = make_validation_model(schema) + model2 = make_validation_model(schema) + assert model1.__name__ == model2.__name__ + + def test_different_schema_produces_different_hash(self): + schema1 = {"type": "object", "properties": {"a": {"type": "string"}}, "required": ["a"]} + schema2 = {"type": "object", "properties": {"b": {"type": "integer"}}, "required": ["b"]} + model1 = make_validation_model(schema1) + model2 = make_validation_model(schema2) + assert model1.__name__ != model2.__name__