Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from urllib.parse import urlsplit

import anyio
import httpx
Expand Down Expand Up @@ -81,6 +82,19 @@ def __init__(self, url: str) -> None:
self.url = url
self.session_id: str | None = None
self.protocol_version: str | None = None
self._default_origin = self._derive_origin(url)

@staticmethod
def _derive_origin(url: str) -> str | None:
"""Derive a same-origin ``Origin`` value (scheme://host[:port]) from a URL.

Returns ``None`` when the URL has no scheme or host, in which case no
``Origin`` header is added.
"""
parsed = urlsplit(url)
if not parsed.scheme or not parsed.netloc:
return None
return f"{parsed.scheme}://{parsed.netloc}"

def _prepare_headers(self) -> dict[str, str]:
"""Build MCP-specific request headers.
Expand All @@ -92,6 +106,13 @@ def _prepare_headers(self) -> dict[str, str]:
"accept": "application/json, text/event-stream",
"content-type": "application/json",
}
# Send a same-origin Origin header by default so spec-compliant servers
# that enforce anti-DNS-rebinding / CSRF protection (e.g. the Go SDK's
# http.CrossOriginProtection) accept the handshake instead of returning
# 403. Callers needing a different Origin can set one on the underlying
# httpx client's default headers.
if self._default_origin is not None:
headers["origin"] = self._default_origin
# Add session headers if available
if self.session_id:
headers[MCP_SESSION_ID] = self.session_id
Expand Down
20 changes: 20 additions & 0 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -1776,6 +1776,26 @@ async def bad_client():
assert tools.tools


def test_prepare_headers_includes_same_origin():
"""Default Origin header is derived from the target URL (scheme://host[:port]).

Regression test for #2727: spec-compliant servers enforcing
anti-DNS-rebinding / CSRF protection reject requests with no Origin.
"""
transport = StreamableHTTPTransport(url="http://my-go-server:8081/mcp")
headers = transport._prepare_headers()
assert headers["origin"] == "http://my-go-server:8081"

https_transport = StreamableHTTPTransport(url="https://example.com/mcp/path?x=1")
assert https_transport._prepare_headers()["origin"] == "https://example.com"


def test_prepare_headers_omits_origin_for_invalid_url():
"""No Origin header is added when the URL lacks a scheme or host."""
transport = StreamableHTTPTransport(url="not-a-url")
assert "origin" not in transport._prepare_headers()


@pytest.mark.anyio
async def test_handle_sse_event_skips_empty_data():
"""Test that _handle_sse_event skips empty SSE data (keep-alive pings)."""
Expand Down
Loading