From 0de5fed08de994b9ad3640e99e66113ce5b3bab3 Mon Sep 17 00:00:00 2001 From: netan-sa Date: Tue, 2 Jun 2026 23:54:28 +0200 Subject: [PATCH] Add: add blocking gates for langgraph_stategraph and langchain_executor --- sdk/adrian/__init__.py | 443 +++++++++++++++++++++++++++++++---------- sdk/adrian/ws.py | 12 ++ 2 files changed, 350 insertions(+), 105 deletions(-) diff --git a/sdk/adrian/__init__.py b/sdk/adrian/__init__.py index 0b1ed09..1d312ab 100644 --- a/sdk/adrian/__init__.py +++ b/sdk/adrian/__init__.py @@ -512,6 +512,23 @@ def _inject_callbacks(config: Any) -> Any: # noqa: ANN401 # ------------------------------------------------------------------ +def _warn_unsupported_frameworks() -> None: + """Warn when raw openai SDK is used without a patchable framework.""" + import importlib.util + + if importlib.util.find_spec("openai") is not None: + # Only warn if openai is present without langchain + if importlib.util.find_spec("langchain_core") is None: + logger.warning( + "Detected raw openai SDK without LangChain/LangGraph. " + "Adrian's pre-execution block gate requires a supported " + "framework (LangGraph ToolNode or LangChain AgentExecutor). " + "Tool calls from raw openai loops are OBSERVED but NOT " + "pre-blocked in MODE_BLOCK. " + "See https://docs.adrian.secureagentics.ai/supported-frameworks" + ) + + def _auto_instrument_langchain() -> None: """Apply all monkey-patches to LangChain / LangGraph.""" try: @@ -520,6 +537,8 @@ def _auto_instrument_langchain() -> None: _patch_chat_model() _patch_langgraph() _patch_tool_node() + _patch_agent_executor() + _warn_unsupported_frameworks() logger.debug("LangChain auto-instrumentation applied") except ImportError: logger.debug("LangChain not found, skipping auto-instrumentation") @@ -531,12 +550,14 @@ def _auto_instrument_langchain() -> None: def _patch_runnable() -> None: - """Patch ``Runnable.invoke`` / ``ainvoke`` to inject callbacks.""" + """Patch ``Runnable.invoke`` / ``ainvoke`` / ``astream`` / ``stream``.""" if getattr(Runnable, "_adrian_patched", False): return original_invoke = Runnable.invoke original_ainvoke = Runnable.ainvoke + original_astream = Runnable.astream + original_stream = Runnable.stream def patched_invoke( self: Any, # noqa: ANN401 @@ -544,9 +565,7 @@ def patched_invoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into sync Runnable call.""" config = _inject_callbacks(config) - return original_invoke(self, input, config, **kwargs) async def patched_ainvoke( @@ -555,15 +574,35 @@ async def patched_ainvoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into async Runnable call.""" config = _inject_callbacks(config) - return await original_ainvoke(self, input, config, **kwargs) + async def patched_astream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + """AgentExecutor calls astream on the agent chain by default.""" + config = _inject_callbacks(config) + async for chunk in original_astream(self, input, config, **kwargs): + yield chunk + + def patched_stream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + yield from original_stream(self, input, config, **kwargs) + Runnable.invoke = patched_invoke # type: ignore[assignment] Runnable.ainvoke = patched_ainvoke # type: ignore[assignment] + Runnable.astream = patched_astream # type: ignore[assignment] + Runnable.stream = patched_stream # type: ignore[assignment] Runnable._adrian_patched = True # type: ignore[attr-defined] - logger.debug("Patched Runnable.invoke / ainvoke") + logger.debug("Patched Runnable.invoke / ainvoke / astream / stream") # --- 2. CallbackManager --- @@ -634,12 +673,14 @@ def patched_configure( def _patch_chat_model() -> None: - """Patch ``BaseChatModel.invoke`` / ``ainvoke`` to inject callbacks.""" + """Patch ``BaseChatModel.invoke`` / ``ainvoke`` / ``astream`` / ``stream``.""" if getattr(BaseChatModel, "_adrian_chat_model_patched", False): return original_invoke = BaseChatModel.invoke original_ainvoke = BaseChatModel.ainvoke + original_astream = BaseChatModel.astream + original_stream = BaseChatModel.stream def patched_invoke( self: Any, # noqa: ANN401 @@ -647,9 +688,7 @@ def patched_invoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into sync chat model call.""" config = _inject_callbacks(config) - return original_invoke(self, input, config=config, **kwargs) async def patched_ainvoke( @@ -658,15 +697,34 @@ async def patched_ainvoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into async chat model call.""" config = _inject_callbacks(config) - return await original_ainvoke(self, input, config=config, **kwargs) + async def patched_astream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + async for chunk in original_astream(self, input, config=config, **kwargs): + yield chunk + + def patched_stream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + yield from original_stream(self, input, config=config, **kwargs) + BaseChatModel.invoke = patched_invoke # type: ignore[assignment] BaseChatModel.ainvoke = patched_ainvoke # type: ignore[assignment] + BaseChatModel.astream = patched_astream # type: ignore[assignment] + BaseChatModel.stream = patched_stream # type: ignore[assignment] BaseChatModel._adrian_chat_model_patched = True # type: ignore[attr-defined] - logger.debug("Patched BaseChatModel.invoke / ainvoke") + logger.debug("Patched BaseChatModel.invoke / ainvoke / astream / stream") # --- 4. LangGraph Pregel --- @@ -761,26 +819,22 @@ async def patched_astream( def _extract_tool_calls( - state: dict[str, Any] | list[BaseMessage], + state: dict[str, Any] | list[BaseMessage] | Any, ) -> list[dict[str, str]]: - """Extract tool_calls from the last AIMessage in ToolNode state. - - LangGraph's ``ToolNode.ainvoke`` accepts two input shapes: a state - dict whose ``"messages"`` key holds the message list, or a bare - list of messages. We handle both. - - Args: - state: The ToolNode input, a state dict with a ``"messages"`` - key, or a direct list of ``BaseMessage`` instances. + """Extract tool_calls from ToolNode input (state dict, message list, or per-tool-call dict).""" + if isinstance(state, dict) and "tool_call" in state: + tc = state["tool_call"] + if isinstance(tc, dict) and tc.get("id"): + return [tc] + if hasattr(tc, "id") and tc.id: + return [{"id": tc.id, "name": getattr(tc, "name", ""), "args": getattr(tc, "args", {})}] - Returns: - List of tool call dicts from the most recent ``AIMessage``, or - an empty list when none is found. - """ if isinstance(state, dict): messages = list(state.get("messages") or []) # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] - else: + elif isinstance(state, list): messages = list(state) + else: + return [] for msg in reversed(messages): if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None): @@ -792,10 +846,7 @@ def _extract_tool_calls( def _should_halt(verdict: pb.Verdict) -> bool: """Decide whether a verdict should halt tool execution. - HITL resolutions override everything: ``continue_execution=False`` - means halt, ``True`` means continue. Otherwise the per-MAD policy - bool is the sole scope authority, if the verdict's tier is - in-scope, halt; if not, continue. + HITL resolutions override the per-MAD policy scope check. """ if verdict.HasField("hitl"): return not verdict.hitl.continue_execution @@ -814,14 +865,7 @@ def _should_halt(verdict: pb.Verdict) -> bool: def _build_blocked_response( tool_calls: list[dict[str, str]], ) -> dict[str, list[ToolMessage]]: - """Build synthetic ToolMessage responses for blocked tool calls. - - Args: - tool_calls: List of tool call dicts extracted from the AIMessage. - - Returns: - Dict in the format ToolNode expects. - """ + """Build synthetic ToolMessage responses for blocked tool calls.""" blocked_messages: list[ToolMessage] = [ ToolMessage( content="[BLOCKED by security policy]", @@ -834,13 +878,67 @@ def _build_blocked_response( return {"messages": blocked_messages} +async def _adrian_tool_gate( + input: Any, # noqa: A002, ANN401 +) -> tuple[str, dict[str, Any] | None]: + """Pre-execution verdict gate. Returns ("halt", response), ("proceed", None), or ("skip", None).""" + ws = _ws_client + + if ws is None: + return ("skip", None) + + if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] + try: + await asyncio.wait_for( + ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] + timeout=5.0, + ) + except TimeoutError: + logger.warning( + "ToolNode: LoginAck not received within 5s; halting " + "(refusing to run a tool without a verified policy)" + ) + return ("halt", _build_blocked_response(_extract_tool_calls(input))) + + if not ws.policy_active(): + return ("skip", None) + + tool_calls = _extract_tool_calls(input) + tool_call_id = next( + (tc.get("id") for tc in tool_calls if tc.get("id")), + None, + ) + + if not tool_call_id: + return ("skip", None) + + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + + verdict = await ws.wait_for_tool_call_verdict(tool_call_id, timeout) + + if verdict is None: + logger.warning( + "verdict timeout for tool_call_id=%s, fail-open", + tool_call_id, + ) + return ("skip", None) + + if _should_halt(verdict): + logger.warning( + "halting tool execution for event_id=%s mad_code=%s", + verdict.event_id, + verdict.mad_code, + ) + return ("halt", _build_blocked_response(tool_calls)) + + return ("proceed", None) + + def _patch_tool_node() -> None: - """Patch ``ToolNode.invoke`` / ``ainvoke``. + """Patch ToolNode._afunc with the verdict gate, and public methods for callback injection. - In block mode, the async patch waits for the preceding LLM's verdict - before executing tools. On BLOCK (unless overridden by ``on_block``) - it returns synthetic ``ToolMessage`` responses instead of running the - tools. On timeout it fails open. + _afunc is the only reliable intercept -- Pregel bypasses ainvoke/astream entirely. """ try: from langgraph.prebuilt import ToolNode @@ -852,6 +950,22 @@ def _patch_tool_node() -> None: original_invoke = ToolNode.invoke original_ainvoke = ToolNode.ainvoke + original_astream = getattr(ToolNode, "astream", None) + original_stream = getattr(ToolNode, "stream", None) + original_afunc = ToolNode._afunc # type: ignore[attr-defined] + + async def patched_afunc( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + runtime: Any = None, # noqa: ANN401 + ) -> Any: # noqa: ANN401 + """Verdict gate on ToolNode._afunc.""" + decision, blocked = await _adrian_tool_gate(input) + if decision == "halt": + return blocked + + return await original_afunc(self, input, config=config, runtime=runtime) def patched_invoke( self: Any, # noqa: ANN401 @@ -859,7 +973,7 @@ def patched_invoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into sync ToolNode invocation.""" + """Inject Adrian callbacks into ToolNode.invoke.""" config = _inject_callbacks(config) return original_invoke(self, input, config=config, **kwargs) @@ -870,75 +984,194 @@ async def patched_ainvoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks; in BLOCK / HITL modes wait for verdict. + """Inject Adrian callbacks into ToolNode.ainvoke.""" + config = _inject_callbacks(config) - Per-tool-call correlation: every tool_call.id is mapped (in - ``WebSocketClient`` ) to the event_id of the LLM that emitted - it. Each ToolNode invocation awaits its specific LLM's verdict, - race-free under parallel agents, no graph-wide pause. - """ + return await original_ainvoke(self, input, config=config, **kwargs) + + async def patched_astream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + """Inject Adrian callbacks into ToolNode.astream.""" config = _inject_callbacks(config) - ws = _ws_client - if ws is None: - return await original_ainvoke(self, input, config=config, **kwargs) + async for chunk in original_astream(self, input, config=config, **kwargs): + yield chunk - # First-tool-call window: the recv loop may not have processed - # ``LoginAck`` yet, so ``policy_active()`` reads False even - # when the org is in BLOCK or HITL. Wait for the LoginAck - # event before checking. If it doesn't arrive within the - # window, halt, refusing to run is the only safe outcome - # when we can't verify the org's policy. - if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] - try: - await asyncio.wait_for( - ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] - timeout=5.0, - ) - except TimeoutError: - logger.warning( - "ToolNode: LoginAck not received within 5s; halting " - "(refusing to run a tool without a verified policy)" - ) - return _build_blocked_response(_extract_tool_calls(input)) - - if not ws.policy_active(): - return await original_ainvoke(self, input, config=config, **kwargs) + ToolNode._afunc = patched_afunc # type: ignore[attr-defined] + ToolNode.invoke = patched_invoke # type: ignore[assignment] + ToolNode.ainvoke = patched_ainvoke # type: ignore[assignment] + if original_astream is not None: + ToolNode.astream = patched_astream # type: ignore[assignment] + ToolNode._adrian_tool_node_patched = True # type: ignore[attr-defined] + logger.debug("Patched ToolNode._afunc / invoke / ainvoke / astream") - tool_calls = _extract_tool_calls(input) - tool_call_id = next( - (tc.get("id") for tc in tool_calls if tc.get("id")), - None, - ) - if not tool_call_id: - # Direct ToolNode invocation outside an LLM flow, no - # producing event_id to wait on, so let the tool run. - return await original_ainvoke(self, input, config=config, **kwargs) +# --- 6. AgentExecutor (langchain / langchain-classic) --- - cfg = _get_config() - timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) - verdict = await ws.wait_for_tool_call_verdict(tool_call_id, timeout) +_BLOCKED_OBSERVATION = "[BLOCKED by security policy]" - if verdict is None: - logger.warning( - "verdict timeout for tool_call_id=%s, fail-open", - tool_call_id, - ) - return await original_ainvoke(self, input, config=config, **kwargs) - if _should_halt(verdict): - logger.warning( - "halting tool execution for event_id=%s mad_code=%s", - verdict.event_id, - verdict.mad_code, - ) - return _build_blocked_response(tool_calls) +def _patch_agent_executor() -> None: + """Patch AgentExecutor tool dispatch with the verdict gate. - return await original_ainvoke(self, input, config=config, **kwargs) + Covers the legacy AgentExecutor path which bypasses ToolNode entirely. + Falls through for ReAct parsers that don't emit tool_call_id. + """ + AgentExecutor = None + AgentStep = None + for mod_path in ("langchain_classic.agents.agent", "langchain.agents.agent"): + try: + mod = __import__(mod_path, fromlist=["AgentExecutor", "AgentStep"]) + AgentExecutor = getattr(mod, "AgentExecutor", None) + AgentStep = getattr(mod, "AgentStep", None) + if AgentExecutor and AgentStep: + break + except ImportError: + continue + + if AgentExecutor is None or AgentStep is None: + return - ToolNode.invoke = patched_invoke # type: ignore[assignment] - ToolNode.ainvoke = patched_ainvoke # type: ignore[assignment] - ToolNode._adrian_tool_node_patched = True # type: ignore[attr-defined] - logger.debug("Patched ToolNode.invoke / ainvoke") + if getattr(AgentExecutor, "_adrian_executor_patched", False): + return + + original_aperform = AgentExecutor._aperform_agent_action + original_perform = AgentExecutor._perform_agent_action + + async def patched_aperform( + self: Any, # noqa: ANN401 + name_to_tool_map: Any, # noqa: ANN401 + color_mapping: Any, # noqa: ANN401 + agent_action: Any, # noqa: ANN401 + run_manager: Any = None, # noqa: ANN401 + ) -> Any: # noqa: ANN401 + """Verdict gate before AgentExecutor dispatches a tool (async).""" + tool_call_id = getattr(agent_action, "tool_call_id", None) + + if tool_call_id: + ws = _ws_client + + if ws is not None: + if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] + try: + await asyncio.wait_for( + ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] + timeout=5.0, + ) + except TimeoutError: + logger.warning( + "AgentExecutor: LoginAck not received within 5s; " + "blocking tool %s", + agent_action.tool, + ) + return AgentStep( + action=agent_action, + observation=_BLOCKED_OBSERVATION, + ) + + if ws.policy_active(): + cfg = _get_config() + # Short timeout: AgentExecutor LLM callbacks may not propagate, + # so verdicts may never arrive. + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + verdict = await ws.wait_for_tool_call_verdict( + tool_call_id, timeout, + ) + + if verdict is not None and _should_halt(verdict): + logger.warning( + "halting tool execution for event_id=%s " + "mad_code=%s (AgentExecutor path)", + verdict.event_id, + verdict.mad_code, + ) + return AgentStep( + action=agent_action, + observation=_BLOCKED_OBSERVATION, + ) + + if verdict is None: + logger.warning( + "AgentExecutor: verdict timeout for " + "tool_call_id=%s, fail-open", + tool_call_id, + ) + + return await original_aperform( + self, name_to_tool_map, color_mapping, agent_action, run_manager, + ) + + def patched_perform( + self: Any, # noqa: ANN401 + name_to_tool_map: Any, # noqa: ANN401 + color_mapping: Any, # noqa: ANN401 + agent_action: Any, # noqa: ANN401 + run_manager: Any = None, # noqa: ANN401 + ) -> Any: # noqa: ANN401 + """Verdict gate before AgentExecutor dispatches a tool (sync).""" + tool_call_id = getattr(agent_action, "tool_call_id", None) + + if tool_call_id: + ws = _ws_client + + if ws is not None and ws._login_ack_received.is_set() and ws.policy_active(): # pyright: ignore[reportPrivateUsage] + import concurrent.futures + + async def _gate() -> bool: + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + verdict = await ws.wait_for_tool_call_verdict( + tool_call_id, timeout, + ) + if verdict is not None and _should_halt(verdict): + logger.warning( + "halting tool execution for event_id=%s " + "mad_code=%s (AgentExecutor sync path)", + verdict.event_id, + verdict.mad_code, + ) + return True + return False + + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + future: concurrent.futures.Future[bool] = concurrent.futures.Future() + + async def _run() -> None: + try: + result = await _gate() + future.set_result(result) + except Exception as exc: + future.set_exception(exc) + + loop.create_task(_run()) + should_block = future.result(timeout=35) + else: + should_block = loop.run_until_complete(_gate()) + + if should_block: + return AgentStep( + action=agent_action, + observation=_BLOCKED_OBSERVATION, + ) + except Exception: + logger.debug( + "AgentExecutor sync gate failed, falling through", + exc_info=True, + ) + + return original_perform( + self, name_to_tool_map, color_mapping, agent_action, run_manager, + ) + + AgentExecutor._aperform_agent_action = patched_aperform # type: ignore[assignment] + AgentExecutor._perform_agent_action = patched_perform # type: ignore[assignment] + AgentExecutor._adrian_executor_patched = True # type: ignore[attr-defined] + logger.debug("Patched AgentExecutor._aperform_agent_action / _perform_agent_action") diff --git a/sdk/adrian/ws.py b/sdk/adrian/ws.py index 1ab5df4..30f1ab5 100644 --- a/sdk/adrian/ws.py +++ b/sdk/adrian/ws.py @@ -513,6 +513,18 @@ async def connect(self) -> None: else: logger.info("WebSocket connected: %s", self._url) + # Eager login: send the SessionLogin frame immediately + # so the server responds with LoginAck before any tool + # gate fires. Previously login was deferred to the + # first _send_frame call, which meant frameworks that + # don't trigger paired events (AgentExecutor) would + # never receive LoginAck and the block gate would time + # out. Provider/model are best-effort at this point + # (empty until the first LLM event auto-detects them). + if not self._logged_in: + await self._send_login(self._ws) + self._logged_in = True + # Drain anything buffered while we were offline, even # on the very first connect. ``_send_mcp_inventory`` # and other init-time emitters queue frames before the