From 97d5b38372bf4a351155c82536b45bfd625348e8 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 15:50:09 +0000 Subject: [PATCH] Bind transport sessions to the authenticated principal MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Both HTTP transports now record the principal that created each session — the OAuth client together with the issuer and subject when the token verifier supplies them — and serve subsequent requests for that session only when they present the same principal. Requests presenting a different principal receive the same 404 response as for an unknown session ID, and SSE session entries are removed when the connection ends. Servers without authentication, and authentication backends other than the built-in BearerAuthBackend, are unaffected: no principal is recorded and the comparison always passes. --- src/mcp/server/auth/middleware/bearer_auth.py | 26 ++- src/mcp/server/sse.py | 65 ++++--- src/mcp/server/streamable_http_manager.py | 49 +++-- tests/server/test_sse_security.py | 135 +++++++++++++- tests/server/test_streamable_http_manager.py | 167 +++++++++++++++++- 5 files changed, 404 insertions(+), 38 deletions(-) diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 64c9b8841f..300b298924 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -1,6 +1,6 @@ import json import time -from typing import Any +from typing import Any, TypedDict from pydantic import AnyHttpUrl from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser @@ -19,6 +19,30 @@ def __init__(self, auth_info: AccessToken): self.scopes = auth_info.scopes +class AuthorizationContext(TypedDict): + client_id: str + issuer: str | None + subject: str | None + + +def authorization_context(user: AuthenticatedUser) -> AuthorizationContext: + """Identify the principal `user` represents, for transports to compare + against the principal that created a session. Components the token + verifier does not supply are `None`, so the comparison degrades to the + remaining components. + + See `examples/servers/simple-auth/mcp_simple_auth/token_verifier.py` for + a verifier that populates `subject` and `claims` from an introspection + response.""" + token = user.access_token + issuer = (token.claims or {}).get("iss") + return AuthorizationContext( + client_id=token.client_id, + issuer=str(issuer) if issuer is not None else None, + subject=token.subject, + ) + + class BearerAuthBackend(AuthenticationBackend): """ Authentication backend that validates Bearer tokens using a TokenVerifier. diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 19af93fd16..489785c4c9 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -52,6 +52,7 @@ async def handle_sse(request): from starlette.types import Receive, Scope, Send import mcp.types as types +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context from mcp.server.transport_security import ( TransportSecurityMiddleware, TransportSecuritySettings, @@ -75,6 +76,9 @@ class SseServerTransport: _endpoint: str _read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]] + # Identity of the credential that created each session; requests for a + # session must present the same credential. + _session_owners: dict[UUID, AuthorizationContext] _security: TransportSecurityMiddleware def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None: @@ -115,6 +119,7 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | self._endpoint = endpoint self._read_stream_writers = {} + self._session_owners = {} self._security = TransportSecurityMiddleware(security_settings) logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") @@ -142,6 +147,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # prag write_stream, write_stream_reader = anyio.create_memory_object_stream(0) session_id = uuid4() + user = scope.get("user") + if isinstance(user, AuthenticatedUser): + self._session_owners[session_id] = authorization_context(user) self._read_stream_writers[session_id] = read_stream_writer logger.debug(f"Created new session with ID: {session_id}") @@ -177,26 +185,34 @@ async def sse_writer(): } ) - async with anyio.create_task_group() as tg: - - async def response_wrapper(scope: Scope, receive: Receive, send: Send): - """ - The EventSourceResponse returning signals a client close / disconnect. - In this case we close our side of the streams to signal the client that - the connection has been closed. - """ - await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( - scope, receive, send - ) - await read_stream_writer.aclose() - await write_stream_reader.aclose() - logging.debug(f"Client session disconnected {session_id}") - - logger.debug("Starting SSE response task") - tg.start_soon(response_wrapper, scope, receive, send) - - logger.debug("Yielding read and write streams") - yield (read_stream, write_stream) + try: + async with anyio.create_task_group() as tg: + + async def response_wrapper(scope: Scope, receive: Receive, send: Send): + """ + The EventSourceResponse returning signals a client close / disconnect. + In this case we close our side of the streams to signal the client that + the connection has been closed. + """ + await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( + scope, receive, send + ) + await read_stream_writer.aclose() + await write_stream_reader.aclose() + await sse_stream_reader.aclose() + logging.debug(f"Client session disconnected {session_id}") + + logger.debug("Starting SSE response task") + tg.start_soon(response_wrapper, scope, receive, send) + + logger.debug("Yielding read and write streams") + yield (read_stream, write_stream) + finally: + # The connection is gone: stop routing messages to this session + # and drop its entries so they do not accumulate for the lifetime + # of the transport. + self._read_stream_writers.pop(session_id, None) + self._session_owners.pop(session_id, None) async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover logger.debug("Handling POST message") @@ -227,6 +243,15 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) response = Response("Could not find session", status_code=404) return await response(scope, receive, send) + user = scope.get("user") + requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None + if requestor != self._session_owners.get(session_id): + # A session can only be used with the credential that created it. + # Respond exactly as if the session did not exist. + logger.warning("Rejecting message for session %s: credential does not match", session_id) + response = Response("Could not find session", status_code=404) + return await response(scope, receive, send) + body = await request.body() logger.debug(f"Received JSON: {body}") diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 8a7b765e86..1a1a85721d 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -5,7 +5,6 @@ import contextlib import logging from collections.abc import AsyncIterator -from http import HTTPStatus from typing import Any from uuid import uuid4 @@ -15,6 +14,7 @@ from starlette.responses import Response from starlette.types import Receive, Scope, Send +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context from mcp.server.lowlevel.server import Server as MCPServer from mcp.server.streamable_http import ( MCP_SESSION_ID_HEADER, @@ -88,6 +88,9 @@ def __init__( # Session tracking (only used if not stateless) self._session_creation_lock = anyio.Lock() self._server_instances: dict[str, StreamableHTTPServerTransport] = {} + # Identity of the credential that created each session; requests for a + # session must present the same credential. + self._session_owners: dict[str, AuthorizationContext] = {} # The task group will be set during lifespan self._task_group = None @@ -135,6 +138,7 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: self._task_group = None # Clear any remaining server instances self._server_instances.clear() + self._session_owners.clear() async def handle_request( self, @@ -227,12 +231,32 @@ async def _handle_stateful_request( request = Request(scope, receive) request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) + user = scope.get("user") + requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None + # Existing session case - if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: # pragma: no cover + if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: transport = self._server_instances[request_mcp_session_id] + if requestor != self._session_owners.get(request_mcp_session_id): + # A session can only be used with the credential that created + # it. Respond exactly as if the session did not exist. + logger.warning( + "Rejecting request for session %s: credential does not match the one that created the session", + request_mcp_session_id[:64], + ) + body = JSONRPCError( + jsonrpc="2.0", id="server-error", error=ErrorData(code=INVALID_REQUEST, message="Session not found") + ) + response = Response( + body.model_dump_json(by_alias=True, exclude_none=True), + status_code=404, + media_type="application/json", + ) + await response(scope, receive, send) + return logger.debug("Session already exists, handling request directly") # Push back idle deadline on activity - if transport.idle_scope is not None and self.session_idle_timeout is not None: + if transport.idle_scope is not None and self.session_idle_timeout is not None: # pragma: no cover transport.idle_scope.deadline = anyio.current_time() + self.session_idle_timeout await transport.handle_request(scope, receive, send) return @@ -251,6 +275,8 @@ async def _handle_stateful_request( ) assert http_transport.mcp_session_id is not None + if requestor is not None: + self._session_owners[http_transport.mcp_session_id] = requestor self._server_instances[http_transport.mcp_session_id] = http_transport logger.info(f"Created new transport with session ID: {new_session_id}") @@ -281,6 +307,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE assert http_transport.mcp_session_id is not None logger.info(f"Session {http_transport.mcp_session_id} idle timeout") self._server_instances.pop(http_transport.mcp_session_id, None) + self._session_owners.pop(http_transport.mcp_session_id, None) await http_transport.terminate() except Exception: logger.exception(f"Session {http_transport.mcp_session_id} crashed") @@ -296,6 +323,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE "active instances." ) del self._server_instances[http_transport.mcp_session_id] + self._session_owners.pop(http_transport.mcp_session_id, None) # Assert task group is not None for type checking assert self._task_group is not None @@ -306,19 +334,10 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE await http_transport.handle_request(scope, receive, send) else: # Unknown or expired session ID - return 404 per MCP spec - # TODO: Align error code once spec clarifies - # See: https://github.com/modelcontextprotocol/python-sdk/issues/1821 - error_response = JSONRPCError( - jsonrpc="2.0", - id="server-error", - error=ErrorData( - code=INVALID_REQUEST, - message="Session not found", - ), + body = JSONRPCError( + jsonrpc="2.0", id="server-error", error=ErrorData(code=INVALID_REQUEST, message="Session not found") ) response = Response( - content=error_response.model_dump_json(by_alias=True, exclude_none=True), - status_code=HTTPStatus.NOT_FOUND, - media_type="application/json", + body.model_dump_json(by_alias=True, exclude_none=True), status_code=404, media_type="application/json" ) await response(scope, receive, send) diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 010eaf6a25..716a308a53 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -1,9 +1,13 @@ -"""Tests for SSE server DNS rebinding protection.""" +"""Tests for SSE server request validation.""" import logging import multiprocessing +import re import socket +from collections.abc import Iterator +from typing import Any +import anyio import httpx import pytest import uvicorn @@ -11,8 +15,11 @@ from starlette.requests import Request from starlette.responses import Response from starlette.routing import Mount, Route +from starlette.types import Message from mcp.server import Server +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AccessToken from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings from mcp.types import Tool @@ -22,6 +29,23 @@ SERVER_NAME = "test_sse_security_server" +@pytest.fixture(autouse=True) +def reset_sse_starlette_exit_event() -> Iterator[None]: + """sse-starlette<2 caches a module-level anyio.Event on AppStatus; clear it + around each test so it is never bound to a closed event loop. Clearing it + afterwards matters too: later test modules fork uvicorn subprocesses on + Linux and would otherwise inherit a stale event.""" + from sse_starlette.sse import AppStatus + + def clear() -> None: + if hasattr(AppStatus, "should_exit_event"): # pragma: no cover + setattr(AppStatus, "should_exit_event", None) + + clear() + yield + clear() + + @pytest.fixture def server_port() -> int: with socket.socket() as s: @@ -291,3 +315,112 @@ async def test_sse_security_post_valid_content_type(server_port: int): finally: process.terminate() process.join() + + +def _authenticated_user(client_id: str, subject: str | None = None, issuer: str | None = None) -> AuthenticatedUser: + """Build the scope["user"] value that AuthenticationMiddleware would set for this principal.""" + claims = {"iss": issuer} if issuer is not None else None + return AuthenticatedUser(AccessToken(token="token", client_id=client_id, scopes=[], subject=subject, claims=claims)) + + +def _sse_scope(method: str, path: str, user: AuthenticatedUser | None) -> dict[str, Any]: + """Build an ASGI scope for a request to the SSE transport.""" + scope: dict[str, Any] = { + "type": "http", + "method": method, + "path": path, + "root_path": "", + "query_string": b"", + "headers": [(b"content-type", b"application/json")], + } + if user is not None: + scope["user"] = user + return scope + + +async def _post_message(transport: SseServerTransport, session_id: str, user: AuthenticatedUser | None) -> int: + """POST a message to an SSE session as `user` and return the response status.""" + body = b'{"jsonrpc": "2.0", "id": 1, "method": "ping", "params": null}' + scope = _sse_scope("POST", "/messages/", user) + scope["query_string"] = f"session_id={session_id}".encode() + sent: list[Message] = [] + + async def receive() -> Message: + return {"type": "http.request", "body": body, "more_body": False} + + async def send(message: Message) -> None: + sent.append(message) + + await transport.handle_post_message(scope, receive, send) + response_start = next(msg for msg in sent if msg["type"] == "http.response.start") + return response_start["status"] + + +_Principal = tuple[str] | tuple[str, str] | tuple[str, str, str] + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("creator", "sender", "expected"), + [ + pytest.param(("client-a",), ("client-b",), 404, id="different-client"), + pytest.param(("client-a",), None, 404, id="unauthenticated-sender"), + pytest.param(("client-a", "alice"), ("client-a", "bob"), 404, id="same-client-different-subject"), + pytest.param(("client-a", "alice"), ("client-a",), 404, id="same-client-no-subject"), + pytest.param( + ("client-a", "alice", "https://i1"), ("client-a", "alice", "https://i2"), 404, id="different-issuer" + ), + pytest.param(None, ("client-a",), 404, id="unauthenticated-creator"), + pytest.param(("client-a",), ("client-a",), 202, id="same-client"), + pytest.param(("client-a", "alice"), ("client-a", "alice"), 202, id="same-client-and-subject"), + pytest.param(None, None, 202, id="both-unauthenticated"), + ], +) +async def test_sse_post_requires_the_credential_that_created_the_session( + creator: _Principal | None, + sender: _Principal | None, + expected: int, +): + """The session endpoint URL issued to one authenticated principal must not + accept messages from a request authenticated as a different one.""" + transport = SseServerTransport("/messages/") + session_id_received = anyio.Event() + session_ids: list[str] = [] + client_disconnected = anyio.Event() + + async def get_send(message: Message) -> None: + # The first body chunk is the SSE event announcing the session URI to POST messages to. + if message["type"] == "http.response.body" and not session_ids: + match = re.search(rb"session_id=([0-9a-f]{32})", message.get("body", b"")) + assert match is not None, f"expected the endpoint event first, got {message!r}" + session_ids.append(match.group(1).decode()) + session_id_received.set() + + async def get_receive() -> Message: + # The SSE client stays connected until the test signals otherwise. + await client_disconnected.wait() + return {"type": "http.disconnect"} + + creator_user = _authenticated_user(*creator) if creator is not None else None + sender_user = _authenticated_user(*sender) if sender is not None else None + + async def hold_sse_connection() -> None: + """Establish the SSE session as `creator` and keep it open, as a server would.""" + scope = _sse_scope("GET", "/sse", creator_user) + with anyio.fail_after(5): + async with transport.connect_sse(scope, get_receive, get_send) as (read_stream, write_stream): + async with read_stream, write_stream: + async for _ in read_stream: + pass + + async with anyio.create_task_group() as tg: + tg.start_soon(hold_sse_connection) + with anyio.fail_after(5): + await session_id_received.wait() + + assert await _post_message(transport, session_ids[0], sender_user) == expected + + client_disconnected.set() + + # Once the connection is gone the session is no longer routable. + assert await _post_message(transport, session_ids[0], creator_user) == 404 diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 33bcb5f2aa..0ae07c43ad 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -6,9 +6,11 @@ import anyio import pytest -from starlette.types import Message +from starlette.types import Message, Scope from mcp.server import streamable_http_manager +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AccessToken from mcp.server.lowlevel import Server from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport from mcp.server.streamable_http_manager import StreamableHTTPSessionManager @@ -390,3 +392,166 @@ def test_session_idle_timeout_rejects_non_positive(): def test_session_idle_timeout_rejects_stateless(): with pytest.raises(RuntimeError, match="not supported in stateless"): StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=30, stateless=True) + + +def _user(client_id: str, subject: str | None = None, issuer: str | None = None) -> AuthenticatedUser: + """Build the scope["user"] value that AuthenticationMiddleware would set for this principal.""" + claims = {"iss": issuer} if issuer is not None else None + return AuthenticatedUser(AccessToken(token="token", client_id=client_id, scopes=[], subject=subject, claims=claims)) + + +def _request_scope( + *, session_id: str | None = None, user: AuthenticatedUser | None = None, method: str = "POST" +) -> Scope: + """Build an ASGI scope for a request to the MCP endpoint.""" + headers = [ + (b"content-type", b"application/json"), + (b"accept", b"application/json, text/event-stream"), + ] + if session_id is not None: + headers.append((b"mcp-session-id", session_id.encode())) + scope: Scope = { + "type": "http", + "method": method, + "path": "/mcp", + "headers": headers, + } + if user is not None: + scope["user"] = user + return scope + + +async def _open_session(manager: StreamableHTTPSessionManager, user: AuthenticatedUser | None) -> str: + """Create a new session as `user` and return its session ID.""" + sent_messages: list[Message] = [] + + async def mock_send(message: Message) -> None: + sent_messages.append(message) + + async def mock_receive() -> Message: + return {"type": "http.request", "body": b"", "more_body": False} + + await manager.handle_request(_request_scope(user=user), mock_receive, mock_send) + + response_start = next(msg for msg in sent_messages if msg["type"] == "http.response.start") + headers = dict(response_start.get("headers", [])) + return headers[MCP_SESSION_ID_HEADER.encode()].decode() + + +async def _request_session( + manager: StreamableHTTPSessionManager, session_id: str, user: AuthenticatedUser | None, method: str = "POST" +) -> int: + """Send a request for an existing session as `user` and return the response status.""" + sent_messages: list[Message] = [] + + async def mock_send(message: Message) -> None: + sent_messages.append(message) + + async def mock_receive() -> Message: + return {"type": "http.request", "body": b"", "more_body": False} + + await manager.handle_request( + _request_scope(session_id=session_id, user=user, method=method), mock_receive, mock_send + ) + + response_start = next(msg for msg in sent_messages if msg["type"] == "http.response.start") + return response_start["status"] + + +@pytest.fixture +async def manager_with_live_session(): + """A running manager around a real `Server`. Sessions remain registered until + `manager.run()` exits because `Server.run` blocks waiting for an initialize message.""" + manager = StreamableHTTPSessionManager(app=Server("test-session-credentials")) + async with manager.run(): + yield manager + + +@pytest.mark.anyio +async def test_session_accepts_requests_from_the_credential_that_created_it( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """Requests presenting the same credential as the one that created the session are served.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a")) + + status = await _request_session(manager, session_id, _user("client-a")) + + # The request passes the manager's credential check and reaches the + # session's transport, instead of being answered with 404 by the manager. + assert status != 404 + + +@pytest.mark.anyio +@pytest.mark.parametrize("method", ["POST", "GET", "DELETE"]) +async def test_session_rejects_requests_from_a_different_credential( + manager_with_live_session: StreamableHTTPSessionManager, method: str +) -> None: + """A session created by one credential cannot be used with another credential, whatever the method.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a")) + + assert await _request_session(manager, session_id, _user("client-b"), method) == 404 + # The session is still registered and still serves its creator. + assert await _request_session(manager, session_id, _user("client-a")) != 404 + + +@pytest.mark.anyio +async def test_session_rejects_requests_from_a_different_subject_of_the_same_client( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """Two end-users that share an OAuth client cannot use each other's sessions.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a", subject="alice")) + + assert await _request_session(manager, session_id, _user("client-a", subject="bob")) == 404 + assert await _request_session(manager, session_id, _user("client-a", subject=None)) == 404 + assert await _request_session(manager, session_id, _user("client-a", subject="alice")) != 404 + + +@pytest.mark.anyio +async def test_session_rejects_requests_with_the_same_subject_from_a_different_issuer( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """A subject is unique only per issuer, so a colliding subject from a different issuer is not the same principal.""" + manager = manager_with_live_session + creator = _user("client-a", subject="alice", issuer="https://issuer.one") + session_id = await _open_session(manager, creator) + + other_issuer = _user("client-a", subject="alice", issuer="https://issuer.two") + assert await _request_session(manager, session_id, other_issuer) == 404 + assert await _request_session(manager, session_id, _user("client-a", subject="alice")) == 404 + assert await _request_session(manager, session_id, creator) != 404 + + +@pytest.mark.anyio +async def test_session_rejects_unauthenticated_requests_for_an_authenticated_session( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """A session created with a credential cannot be used without one.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a")) + + assert await _request_session(manager, session_id, None) == 404 + + +@pytest.mark.anyio +async def test_session_rejects_authenticated_requests_for_an_anonymous_session( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """A session created without a credential cannot be used with one.""" + manager = manager_with_live_session + session_id = await _open_session(manager, None) + + assert await _request_session(manager, session_id, _user("client-a")) == 404 + + +@pytest.mark.anyio +async def test_anonymous_session_accepts_anonymous_requests( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """Servers without authentication keep working: no credential on either side.""" + manager = manager_with_live_session + session_id = await _open_session(manager, None) + + assert await _request_session(manager, session_id, None) != 404