From 3bdfde653cca7783388361495c9bcc3d486edf23 Mon Sep 17 00:00:00 2001 From: Ritwij Aryan Parmar Date: Fri, 29 May 2026 13:18:22 -0400 Subject: [PATCH] Add skill lanes for MCP tool serialization Signed-off-by: Ritwij Aryan Parmar --- dimos/agents/annotation.py | 28 +++++++++++++-- dimos/agents/mcp/mcp_server.py | 31 +++++++++++++--- dimos/agents/mcp/test_mcp_server.py | 56 +++++++++++++++++++++++++++++ dimos/agents/test_skill_result.py | 17 +++++++++ dimos/core/module.py | 6 +++- 5 files changed, 130 insertions(+), 8 deletions(-) diff --git a/dimos/agents/annotation.py b/dimos/agents/annotation.py index ed4c6235e5..4c5090ee0b 100644 --- a/dimos/agents/annotation.py +++ b/dimos/agents/annotation.py @@ -18,7 +18,7 @@ import inspect import threading import time -from typing import Any, TypeVar, cast +from typing import Any, TypeVar, cast, overload from dimos.core.core import rpc from dimos.utils.logging_config import setup_logger @@ -66,7 +66,10 @@ def _stamp_and_log(func_name: str, result: Any, elapsed_ms: float) -> Any: return result -def skill(func: F) -> F: +def _decorate_skill(func: F, *, lane: str | None = None) -> F: + if lane == "": + raise ValueError("skill lane must be a non-empty string or None") + if inspect.iscoroutinefunction(func): @functools.wraps(func) @@ -108,4 +111,25 @@ def sync_context_wrapper(*args: Any, **kwargs: Any) -> Any: wrapped = rpc(context_wrapper) wrapped.__skill__ = True # type: ignore[attr-defined] + wrapped.__skill_lane__ = lane # type: ignore[attr-defined] return cast("F", wrapped) + + +@overload +def skill(func: F) -> F: ... + + +@overload +def skill(*, lane: str | None = None) -> Callable[[F], F]: ... + + +def skill(func: F | None = None, *, lane: str | None = None) -> F | Callable[[F], F]: + """Mark a method as an agent skill. + + `lane` optionally names a sequential execution lane. MCP callers serialize + calls within the same lane while leaving unlaned or differently-laned skills + free to run concurrently. + """ + if func is None: + return lambda wrapped_func: _decorate_skill(wrapped_func, lane=lane) + return _decorate_skill(func, lane=lane) diff --git a/dimos/agents/mcp/mcp_server.py b/dimos/agents/mcp/mcp_server.py index dbd31f8d87..e32fab5625 100644 --- a/dimos/agents/mcp/mcp_server.py +++ b/dimos/agents/mcp/mcp_server.py @@ -54,6 +54,7 @@ app.state.rpc_calls = {} app.state.sse_queues = [] app.state.event_loop = None +app.state.lane_locks = {} def _jsonrpc_result(req_id: Any, result: Any) -> dict[str, Any]: @@ -89,13 +90,19 @@ def _handle_tools_list(req_id: Any, skills: list[SkillInfo]) -> dict[str, Any]: tool: dict[str, Any] = {"name": s.func_name, "inputSchema": schema} if description: tool["description"] = description + if s.lane is not None: + tool["_meta"] = {"lane": s.lane} tools.append(tool) return _jsonrpc_result(req_id, {"tools": tools}) async def _handle_tools_call( - req_id: Any, params: dict[str, Any], rpc_calls: dict[str, Any] + req_id: Any, + params: dict[str, Any], + skills: list[SkillInfo], + rpc_calls: dict[str, Any], + lane_locks: dict[str, asyncio.Lock], ) -> dict[str, Any]: name = params.get("name", "") args: dict[str, Any] = params.get("arguments") or {} @@ -107,6 +114,8 @@ async def _handle_tools_call( logger.warning("MCP tool not found", tool=name) return _jsonrpc_result_text(req_id, f"Tool not found: {name}") + lane = next((s.lane for s in skills if s.func_name == name), None) + logger.info("MCP tool call", tool=name, args=args, progress_token=progress_token) t0 = time.monotonic() @@ -116,10 +125,16 @@ async def _handle_tools_call( if progress_token is not None: call_kwargs["_mcp_context"] = {"progress_token": progress_token} + async def run_rpc_call() -> Any: + return await asyncio.get_event_loop().run_in_executor(None, lambda: rpc_call(**call_kwargs)) + try: - result = await asyncio.get_event_loop().run_in_executor( - None, lambda: rpc_call(**call_kwargs) - ) + if lane is None: + result = await run_rpc_call() + else: + lock = lane_locks.setdefault(lane, asyncio.Lock()) + async with lock: + result = await run_rpc_call() except Exception as e: logger.exception("MCP tool error", tool=name, duration=f"{time.monotonic() - t0:.3f}s") return _jsonrpc_result_text(req_id, f"Error running tool '{name}': {e}") @@ -158,7 +173,13 @@ async def handle_request( if method == "tools/list": return _handle_tools_list(req_id, skills) if method == "tools/call": - return await _handle_tools_call(req_id, params, rpc_calls) + return await _handle_tools_call( + req_id, + params, + skills, + rpc_calls, + app.state.lane_locks, + ) return _jsonrpc_error(req_id, -32601, f"Unknown: {method}") diff --git a/dimos/agents/mcp/test_mcp_server.py b/dimos/agents/mcp/test_mcp_server.py index fd514b0643..00f0f88ebe 100644 --- a/dimos/agents/mcp/test_mcp_server.py +++ b/dimos/agents/mcp/test_mcp_server.py @@ -16,6 +16,8 @@ import asyncio import json +import threading +import time from unittest.mock import MagicMock from dimos.agents.mcp.mcp_server import handle_request @@ -68,6 +70,60 @@ def test_mcp_module_request_flow() -> None: rpc_calls["add"].assert_called_once_with(x=2, y=3) +def test_mcp_module_lists_skill_lane_metadata() -> None: + schema = json.dumps({"type": "object", "properties": {}}) + skills = [ + SkillInfo(class_name="TestSkills", func_name="drive", args_schema=schema, lane="motion") + ] + rpc_calls = _make_rpc_calls(skills, {"drive": "ok"}) + + response = asyncio.run(handle_request({"method": "tools/list", "id": 1}, skills, rpc_calls)) + + assert response["result"]["tools"][0]["_meta"] == {"lane": "motion"} + + +def test_mcp_module_serializes_same_lane_calls() -> None: + schema = json.dumps({"type": "object", "properties": {}}) + skills = [ + SkillInfo(class_name="TestSkills", func_name="drive", args_schema=schema, lane="motion") + ] + rpc_calls = _make_rpc_calls(skills, {}) + active = 0 + max_active = 0 + lock = threading.Lock() + + def slow_drive() -> str: + nonlocal active, max_active + with lock: + active += 1 + max_active = max(max_active, active) + time.sleep(0.05) + with lock: + active -= 1 + return "ok" + + rpc_calls["drive"].side_effect = slow_drive + + async def run_two_calls() -> None: + await asyncio.gather( + handle_request( + {"method": "tools/call", "id": 1, "params": {"name": "drive", "arguments": {}}}, + skills, + rpc_calls, + ), + handle_request( + {"method": "tools/call", "id": 2, "params": {"name": "drive", "arguments": {}}}, + skills, + rpc_calls, + ), + ) + + asyncio.run(run_two_calls()) + + assert rpc_calls["drive"].call_count == 2 + assert max_active == 1 + + def test_mcp_module_injects_progress_token_as_mcp_context() -> None: """When the client sends `_meta.progressToken`, the RPC call receives it as an `_mcp_context` kwarg so the `@skill` wrapper can stash it in the diff --git a/dimos/agents/test_skill_result.py b/dimos/agents/test_skill_result.py index c4aae76a17..97914be8d9 100644 --- a/dimos/agents/test_skill_result.py +++ b/dimos/agents/test_skill_result.py @@ -206,3 +206,20 @@ def my_skill() -> SkillResult: assert sentinel.duration_ms == 999.0 # untouched # Decorator overwrites with actual measured elapsed (very small). assert result.duration_ms != 999.0 + + def test_parameterized_decorator_records_lane(self): + @skill(lane="motion") + def go_home() -> SkillResult: + return SkillResult.ok("done") + + assert go_home.__skill__ is True + assert go_home.__skill_lane__ == "motion" + assert go_home().is_success() + + def test_bare_decorator_records_no_lane(self): + @skill + def inspect_scene() -> SkillResult: + return SkillResult.ok("done") + + assert inspect_scene.__skill__ is True + assert inspect_scene.__skill_lane__ is None diff --git a/dimos/core/module.py b/dimos/core/module.py index f2aed9d185..208abd1151 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -69,6 +69,7 @@ class SkillInfo: class_name: str func_name: str args_schema: str + lane: str | None = None class PeekNotFound: @@ -440,7 +441,10 @@ def get_skills(self) -> list[SkillInfo]: schema = json.dumps(tool(attr).args_schema.model_json_schema()) skills.append( SkillInfo( - class_name=self.__class__.__name__, func_name=name, args_schema=schema + class_name=self.__class__.__name__, + func_name=name, + args_schema=schema, + lane=getattr(attr, "__skill_lane__", None), ) ) return skills