diff --git a/CLAUDE.md b/CLAUDE.md index 77d4270..59bdd5b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -21,7 +21,7 @@ If a task contradicts a confirmed decision, **stop and ask** before coding aroun ## Build commands -JDK 21 required. The Gradle wrapper is committed. +JDK 21 and [cosign](https://docs.sigstore.dev/cosign/system_config/installation/) required. The Gradle wrapper is committed. Install cosign via `brew install cosign` (macOS) — the build uses it to verify schema bundle signatures. ```bash ./gradlew build # full build, all 8 modules diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 016c6be..f6995f5 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -28,6 +28,7 @@ Before designing anything new, read the [21 confirmed post-RFC decisions](ROADMA Requirements: - JDK 21 (Temurin recommended) +- [cosign](https://docs.sigstore.dev/cosign/system_config/installation/) — the build shells out to `cosign verify-blob` to verify the schema bundle signature (per D4). Install via `brew install cosign` (macOS) or `go install github.com/sigstore/cosign/v2/cmd/cosign@latest`. - The Gradle wrapper (committed) — don't install Gradle separately. Local build: diff --git a/ROADMAP.md b/ROADMAP.md index b230361..92d3d83 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -223,7 +223,7 @@ The RFC's M+12 target is the realistic line. Pre-committing M+9 and slipping is ### Track 3 — L0 transport: MCP + A2A -**ID:** `transport` | **Owner:** TBD | **Size:** 1.5 person-months +**ID:** `transport` | **Owner:** @MichielDean (#17) | **Size:** 1.5 person-months **Scope:** @@ -528,6 +528,6 @@ Additional decisions added post-RFC that remain open: | Implementation plan drafted | ✅ (this doc) | | Confirmed decisions D1–D21 locked | ✅ | | Funding / staffing confirmed | ⏳ Decision pending | -| Tracks claimed | 0 / 14 | +| Tracks claimed | 3 / 14 — `infra` (Track 1, #2), `codegen` (Track 2, #11), `transport` (Track 3, #17) | | Pre-contributor harness | 🟡 In progress — Gradle skeleton, codegen MVP, SSRF skeleton, schema fetcher, mock-server CI gate, IPR workflow, commitlint, changesets, MCP prototype findings all landed. Foundation admin actions outstanding: IPR Bot install, DNS TXT for Sonatype, @MichielDean collaborator. | | v0.1 alpha | Not Started | diff --git a/adcp-cli/gradle.lockfile b/adcp-cli/gradle.lockfile index f5d76c1..3821060 100644 --- a/adcp-cli/gradle.lockfile +++ b/adcp-cli/gradle.lockfile @@ -2,13 +2,21 @@ # Manual edits can break the build and are not advised. # This file is expected to be part of source control. com.ethlo.time:itu:1.10.3=runtimeClasspath,testRuntimeClasspath -com.fasterxml.jackson.core:jackson-annotations:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.fasterxml.jackson.core:jackson-core:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.fasterxml.jackson.core:jackson-databind:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:2.18.2=runtimeClasspath,testRuntimeClasspath -com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.fasterxml.jackson:jackson-bom:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath +com.fasterxml.jackson.core:jackson-annotations:2.18.2=compileClasspath,testCompileClasspath +com.fasterxml.jackson.core:jackson-annotations:2.20=runtimeClasspath,testRuntimeClasspath +com.fasterxml.jackson.core:jackson-core:2.18.2=compileClasspath,testCompileClasspath +com.fasterxml.jackson.core:jackson-core:2.20.1=runtimeClasspath,testRuntimeClasspath +com.fasterxml.jackson.core:jackson-databind:2.18.2=compileClasspath,testCompileClasspath +com.fasterxml.jackson.core:jackson-databind:2.20.1=runtimeClasspath,testRuntimeClasspath +com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:2.20.1=runtimeClasspath,testRuntimeClasspath +com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.18.2=compileClasspath,testCompileClasspath +com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.20.1=runtimeClasspath,testRuntimeClasspath +com.fasterxml.jackson:jackson-bom:2.18.2=compileClasspath,testCompileClasspath +com.fasterxml.jackson:jackson-bom:2.20.1=runtimeClasspath,testRuntimeClasspath com.networknt:json-schema-validator:1.5.6=runtimeClasspath,testRuntimeClasspath +io.modelcontextprotocol.sdk:mcp-core:1.1.2=runtimeClasspath,testRuntimeClasspath +io.modelcontextprotocol.sdk:mcp-json-jackson2:1.1.2=runtimeClasspath,testRuntimeClasspath +io.projectreactor:reactor-core:3.7.0=runtimeClasspath,testRuntimeClasspath org.apiguardian:apiguardian-api:1.1.2=compileClasspath,testCompileClasspath org.jspecify:jspecify:1.0.0=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath org.junit.jupiter:junit-jupiter-api:5.11.4=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath @@ -19,7 +27,8 @@ org.junit.platform:junit-platform-engine:1.11.4=testRuntimeClasspath org.junit.platform:junit-platform-launcher:1.11.4=testRuntimeClasspath org.junit:junit-bom:5.11.4=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath org.opentest4j:opentest4j:1.3.0=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath +org.reactivestreams:reactive-streams:1.0.4=runtimeClasspath,testRuntimeClasspath org.slf4j:slf4j-api:2.0.16=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath org.slf4j:slf4j-simple:2.0.16=runtimeClasspath,testRuntimeClasspath -org.yaml:snakeyaml:2.3=runtimeClasspath,testRuntimeClasspath +org.yaml:snakeyaml:2.4=runtimeClasspath,testRuntimeClasspath empty=annotationProcessor,testAnnotationProcessor diff --git a/adcp-kotlin/gradle.lockfile b/adcp-kotlin/gradle.lockfile index 18db7ff..1710da3 100644 --- a/adcp-kotlin/gradle.lockfile +++ b/adcp-kotlin/gradle.lockfile @@ -2,13 +2,21 @@ # Manual edits can break the build and are not advised. # This file is expected to be part of source control. com.ethlo.time:itu:1.10.3=runtimeClasspath,testRuntimeClasspath -com.fasterxml.jackson.core:jackson-annotations:2.18.2=apiDependenciesMetadata,compileClasspath,implementationDependenciesMetadata,runtimeClasspath,testCompileClasspath,testImplementationDependenciesMetadata,testRuntimeClasspath -com.fasterxml.jackson.core:jackson-core:2.18.2=apiDependenciesMetadata,compileClasspath,implementationDependenciesMetadata,runtimeClasspath,testCompileClasspath,testImplementationDependenciesMetadata,testRuntimeClasspath -com.fasterxml.jackson.core:jackson-databind:2.18.2=apiDependenciesMetadata,compileClasspath,implementationDependenciesMetadata,runtimeClasspath,testCompileClasspath,testImplementationDependenciesMetadata,testRuntimeClasspath -com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:2.18.2=runtimeClasspath,testRuntimeClasspath -com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.18.2=apiDependenciesMetadata,compileClasspath,implementationDependenciesMetadata,runtimeClasspath,testCompileClasspath,testImplementationDependenciesMetadata,testRuntimeClasspath -com.fasterxml.jackson:jackson-bom:2.18.2=apiDependenciesMetadata,compileClasspath,implementationDependenciesMetadata,runtimeClasspath,testCompileClasspath,testImplementationDependenciesMetadata,testRuntimeClasspath +com.fasterxml.jackson.core:jackson-annotations:2.18.2=apiDependenciesMetadata,compileClasspath,implementationDependenciesMetadata,testCompileClasspath,testImplementationDependenciesMetadata +com.fasterxml.jackson.core:jackson-annotations:2.20=runtimeClasspath,testRuntimeClasspath +com.fasterxml.jackson.core:jackson-core:2.18.2=apiDependenciesMetadata,compileClasspath,implementationDependenciesMetadata,testCompileClasspath,testImplementationDependenciesMetadata +com.fasterxml.jackson.core:jackson-core:2.20.1=runtimeClasspath,testRuntimeClasspath +com.fasterxml.jackson.core:jackson-databind:2.18.2=apiDependenciesMetadata,compileClasspath,implementationDependenciesMetadata,testCompileClasspath,testImplementationDependenciesMetadata +com.fasterxml.jackson.core:jackson-databind:2.20.1=runtimeClasspath,testRuntimeClasspath +com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:2.20.1=runtimeClasspath,testRuntimeClasspath +com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.18.2=apiDependenciesMetadata,compileClasspath,implementationDependenciesMetadata,testCompileClasspath,testImplementationDependenciesMetadata +com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.20.1=runtimeClasspath,testRuntimeClasspath +com.fasterxml.jackson:jackson-bom:2.18.2=apiDependenciesMetadata,compileClasspath,implementationDependenciesMetadata,testCompileClasspath,testImplementationDependenciesMetadata +com.fasterxml.jackson:jackson-bom:2.20.1=runtimeClasspath,testRuntimeClasspath com.networknt:json-schema-validator:1.5.6=runtimeClasspath,testRuntimeClasspath +io.modelcontextprotocol.sdk:mcp-core:1.1.2=runtimeClasspath,testRuntimeClasspath +io.modelcontextprotocol.sdk:mcp-json-jackson2:1.1.2=runtimeClasspath,testRuntimeClasspath +io.projectreactor:reactor-core:3.7.0=runtimeClasspath,testRuntimeClasspath org.apiguardian:apiguardian-api:1.1.2=testCompileClasspath,testImplementationDependenciesMetadata org.jetbrains.intellij.deps:trove4j:1.0.20200330=kotlinBuildToolsApiClasspath,kotlinCompilerClasspath,kotlinKlibCommonizerClasspath org.jetbrains.kotlin:kotlin-build-common:2.1.10=kotlinBuildToolsApiClasspath @@ -43,6 +51,7 @@ org.junit.platform:junit-platform-engine:1.11.4=testRuntimeClasspath org.junit.platform:junit-platform-launcher:1.11.4=testRuntimeClasspath org.junit:junit-bom:5.11.4=testCompileClasspath,testImplementationDependenciesMetadata,testRuntimeClasspath org.opentest4j:opentest4j:1.3.0=testCompileClasspath,testImplementationDependenciesMetadata,testRuntimeClasspath +org.reactivestreams:reactive-streams:1.0.4=runtimeClasspath,testRuntimeClasspath org.slf4j:slf4j-api:2.0.16=apiDependenciesMetadata,compileClasspath,implementationDependenciesMetadata,runtimeClasspath,testCompileClasspath,testImplementationDependenciesMetadata,testRuntimeClasspath -org.yaml:snakeyaml:2.3=runtimeClasspath,testRuntimeClasspath +org.yaml:snakeyaml:2.4=runtimeClasspath,testRuntimeClasspath empty=annotationProcessor,intransitiveDependenciesMetadata,kotlinCompilerPluginClasspath,kotlinNativeCompilerPluginClasspath,kotlinScriptDefExtensions,testAnnotationProcessor,testApiDependenciesMetadata,testCompileOnlyDependenciesMetadata,testIntransitiveDependenciesMetadata,testKotlinScriptDefExtensions diff --git a/adcp-mutiny/gradle.lockfile b/adcp-mutiny/gradle.lockfile index 6cbcf4c..9393c4e 100644 --- a/adcp-mutiny/gradle.lockfile +++ b/adcp-mutiny/gradle.lockfile @@ -2,13 +2,21 @@ # Manual edits can break the build and are not advised. # This file is expected to be part of source control. com.ethlo.time:itu:1.10.3=runtimeClasspath,testRuntimeClasspath -com.fasterxml.jackson.core:jackson-annotations:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.fasterxml.jackson.core:jackson-core:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.fasterxml.jackson.core:jackson-databind:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:2.18.2=runtimeClasspath,testRuntimeClasspath -com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.fasterxml.jackson:jackson-bom:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath +com.fasterxml.jackson.core:jackson-annotations:2.18.2=compileClasspath,testCompileClasspath +com.fasterxml.jackson.core:jackson-annotations:2.20=runtimeClasspath,testRuntimeClasspath +com.fasterxml.jackson.core:jackson-core:2.18.2=compileClasspath,testCompileClasspath +com.fasterxml.jackson.core:jackson-core:2.20.1=runtimeClasspath,testRuntimeClasspath +com.fasterxml.jackson.core:jackson-databind:2.18.2=compileClasspath,testCompileClasspath +com.fasterxml.jackson.core:jackson-databind:2.20.1=runtimeClasspath,testRuntimeClasspath +com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:2.20.1=runtimeClasspath,testRuntimeClasspath +com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.18.2=compileClasspath,testCompileClasspath +com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.20.1=runtimeClasspath,testRuntimeClasspath +com.fasterxml.jackson:jackson-bom:2.18.2=compileClasspath,testCompileClasspath +com.fasterxml.jackson:jackson-bom:2.20.1=runtimeClasspath,testRuntimeClasspath com.networknt:json-schema-validator:1.5.6=runtimeClasspath,testRuntimeClasspath +io.modelcontextprotocol.sdk:mcp-core:1.1.2=runtimeClasspath,testRuntimeClasspath +io.modelcontextprotocol.sdk:mcp-json-jackson2:1.1.2=runtimeClasspath,testRuntimeClasspath +io.projectreactor:reactor-core:3.7.0=runtimeClasspath,testRuntimeClasspath io.smallrye.common:smallrye-common-annotation:2.8.0=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath io.smallrye.reactive:mutiny:2.7.0=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath org.apiguardian:apiguardian-api:1.1.2=testCompileClasspath @@ -22,6 +30,7 @@ org.junit.platform:junit-platform-engine:1.11.4=testRuntimeClasspath org.junit.platform:junit-platform-launcher:1.11.4=testRuntimeClasspath org.junit:junit-bom:5.11.4=testCompileClasspath,testRuntimeClasspath org.opentest4j:opentest4j:1.3.0=testCompileClasspath,testRuntimeClasspath +org.reactivestreams:reactive-streams:1.0.4=runtimeClasspath,testRuntimeClasspath org.slf4j:slf4j-api:2.0.16=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -org.yaml:snakeyaml:2.3=runtimeClasspath,testRuntimeClasspath +org.yaml:snakeyaml:2.4=runtimeClasspath,testRuntimeClasspath empty=annotationProcessor,testAnnotationProcessor diff --git a/adcp-reactor/gradle.lockfile b/adcp-reactor/gradle.lockfile index 7f1c135..8cd3d80 100644 --- a/adcp-reactor/gradle.lockfile +++ b/adcp-reactor/gradle.lockfile @@ -2,13 +2,20 @@ # Manual edits can break the build and are not advised. # This file is expected to be part of source control. com.ethlo.time:itu:1.10.3=runtimeClasspath,testRuntimeClasspath -com.fasterxml.jackson.core:jackson-annotations:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.fasterxml.jackson.core:jackson-core:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.fasterxml.jackson.core:jackson-databind:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:2.18.2=runtimeClasspath,testRuntimeClasspath -com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.fasterxml.jackson:jackson-bom:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath +com.fasterxml.jackson.core:jackson-annotations:2.18.2=compileClasspath,testCompileClasspath +com.fasterxml.jackson.core:jackson-annotations:2.20=runtimeClasspath,testRuntimeClasspath +com.fasterxml.jackson.core:jackson-core:2.18.2=compileClasspath,testCompileClasspath +com.fasterxml.jackson.core:jackson-core:2.20.1=runtimeClasspath,testRuntimeClasspath +com.fasterxml.jackson.core:jackson-databind:2.18.2=compileClasspath,testCompileClasspath +com.fasterxml.jackson.core:jackson-databind:2.20.1=runtimeClasspath,testRuntimeClasspath +com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:2.20.1=runtimeClasspath,testRuntimeClasspath +com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.18.2=compileClasspath,testCompileClasspath +com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.20.1=runtimeClasspath,testRuntimeClasspath +com.fasterxml.jackson:jackson-bom:2.18.2=compileClasspath,testCompileClasspath +com.fasterxml.jackson:jackson-bom:2.20.1=runtimeClasspath,testRuntimeClasspath com.networknt:json-schema-validator:1.5.6=runtimeClasspath,testRuntimeClasspath +io.modelcontextprotocol.sdk:mcp-core:1.1.2=runtimeClasspath,testRuntimeClasspath +io.modelcontextprotocol.sdk:mcp-json-jackson2:1.1.2=runtimeClasspath,testRuntimeClasspath io.projectreactor:reactor-core:3.7.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath org.apiguardian:apiguardian-api:1.1.2=testCompileClasspath org.jspecify:jspecify:1.0.0=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath @@ -22,5 +29,5 @@ org.junit:junit-bom:5.11.4=testCompileClasspath,testRuntimeClasspath org.opentest4j:opentest4j:1.3.0=testCompileClasspath,testRuntimeClasspath org.reactivestreams:reactive-streams:1.0.4=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath org.slf4j:slf4j-api:2.0.16=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -org.yaml:snakeyaml:2.3=runtimeClasspath,testRuntimeClasspath +org.yaml:snakeyaml:2.4=runtimeClasspath,testRuntimeClasspath empty=annotationProcessor,testAnnotationProcessor diff --git a/adcp-server/src/main/java/org/adcontextprotocol/adcp/server/AdcpContext.java b/adcp-server/src/main/java/org/adcontextprotocol/adcp/server/AdcpContext.java new file mode 100644 index 0000000..3f55da5 --- /dev/null +++ b/adcp-server/src/main/java/org/adcontextprotocol/adcp/server/AdcpContext.java @@ -0,0 +1,27 @@ +package org.adcontextprotocol.adcp.server; + +import org.adcontextprotocol.adcp.AdcpVersion; +import org.jspecify.annotations.Nullable; + +import java.util.Map; + +/** + * Context passed to {@link AdcpPlatform} tool handlers. + * + *

Carries per-request metadata: the caller's identity, the negotiated + * protocol version, and any headers the handler might need. + * + * @param adcpVersion the protocol version from the request envelope + * @param headers all inbound request headers + * @param requestId the MCP request ID (for correlation) + */ +public record AdcpContext( + @Nullable AdcpVersion adcpVersion, + Map headers, + @Nullable String requestId +) { + + public AdcpContext { + headers = Map.copyOf(headers); + } +} diff --git a/adcp-server/src/main/java/org/adcontextprotocol/adcp/server/AdcpPlatform.java b/adcp-server/src/main/java/org/adcontextprotocol/adcp/server/AdcpPlatform.java new file mode 100644 index 0000000..6efe3cc --- /dev/null +++ b/adcp-server/src/main/java/org/adcontextprotocol/adcp/server/AdcpPlatform.java @@ -0,0 +1,93 @@ +package org.adcontextprotocol.adcp.server; + +import org.adcontextprotocol.adcp.error.UnsupportedTaskError; + +import java.util.Map; + +/** + * Service Provider Interface for AdCP agent implementations. + * + *

Adopters extend this class, override {@link #supportedTools()} to + * declare which tools to advertise via MCP {@code tools/list}, and + * override {@link #handleTool(String, Map, AdcpContext)} to dispatch + * them. Only tools returned by {@link #supportedTools()} are registered + * with the MCP server — unregistered tools are never advertised. + * + *

Each method receives a typed request and an {@link AdcpContext} + * with per-request metadata (protocol version, headers, etc.). + * + *

Example: + *

{@code
+ * public class MyPlatform extends AdcpPlatform {
+ *     @Override
+ *     public Set supportedTools() {
+ *         return Set.of("get_products", "get_creatives");
+ *     }
+ *
+ *     @Override
+ *     public Object handleTool(String toolName, Map request, AdcpContext ctx) {
+ *         return switch (toolName) {
+ *             case "get_products" -> getProducts(request, ctx);
+ *             case "get_creatives" -> getCreatives(request, ctx);
+ *             default -> super.handleTool(toolName, request, ctx);
+ *         };
+ *     }
+ * }
+ * }
+ */ +public abstract class AdcpPlatform { + + /** + * Dispatches a tool call by name. Override this to handle specific tools. + * + *

The default implementation throws {@link UnsupportedTaskError}, + * signaling that the tool is not implemented. + * + * @param toolName the MCP tool name (e.g. "get_products") + * @param request the deserialized request arguments + * @param ctx per-request context + * @return the response object (will be serialized by the framework) + * @throws UnsupportedTaskError if the tool is not implemented + */ + public Object handleTool(String toolName, Map request, AdcpContext ctx) { + throw new UnsupportedTaskError(toolName); + } + + /** + * Returns the set of tool names this platform supports. + * + *

Override this to declare which tools your platform advertises. + * The default returns an empty set (no tools). + * + * @return tool names supported by this platform + */ + public java.util.Set supportedTools() { + return java.util.Set.of(); + } + + /** + * Returns human-readable descriptions for each tool, keyed by tool name. + * + *

Override this to provide descriptions that help MCP clients + * (and LLMs) understand when to invoke each tool. If a tool has no + * entry in this map, its name is used as the description. + * + * @return map of tool name → description + */ + public java.util.Map toolDescriptions() { + return java.util.Map.of(); + } + + /** + * Returns JSON Schemas for each tool's input, keyed by tool name. + * + *

Override this to expose typed validation schemas to MCP clients. + * If a tool has no entry, a permissive open-object schema is used. + * Each value must be a valid {@link io.modelcontextprotocol.spec.McpSchema.JsonSchema}. + * + * @return map of tool name → input schema + */ + public java.util.Map toolSchemas() { + return java.util.Map.of(); + } +} diff --git a/adcp-server/src/main/java/org/adcontextprotocol/adcp/server/AdcpServerBuilder.java b/adcp-server/src/main/java/org/adcontextprotocol/adcp/server/AdcpServerBuilder.java new file mode 100644 index 0000000..b453e97 --- /dev/null +++ b/adcp-server/src/main/java/org/adcontextprotocol/adcp/server/AdcpServerBuilder.java @@ -0,0 +1,231 @@ +package org.adcontextprotocol.adcp.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.adcontextprotocol.adcp.AdcpVersion; +import org.adcontextprotocol.adcp.schema.AdcpObjectMapperFactory; +import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +/** + * Builds and wires an MCP server backed by an {@link AdcpPlatform}. + * + *

Introspects the platform's {@link AdcpPlatform#supportedTools()} to + * register MCP tool handlers. Only supported tools are advertised via + * {@code tools/list}. + * + *

Usage: + *

{@code
+ * McpServerTransportProvider transport = ...;
+ * AdcpServerBuilder.create(myPlatform)
+ *     .transport(transport)
+ *     .build()
+ *     .initialize();
+ * }
+ */ +public final class AdcpServerBuilder { + + private static final Logger log = LoggerFactory.getLogger(AdcpServerBuilder.class); + + private final AdcpPlatform platform; + private @Nullable McpServerTransportProvider transport; + private @Nullable ObjectMapper objectMapper; + private @Nullable AdcpVersion adcpVersion; + private String serverName = "adcp-java-sdk"; + private String serverVersion = "0.1.0"; + + private AdcpServerBuilder(AdcpPlatform platform) { + this.platform = Objects.requireNonNull(platform, "platform"); + } + + /** Creates a new server builder for the given platform. */ + public static AdcpServerBuilder create(AdcpPlatform platform) { + return new AdcpServerBuilder(platform); + } + + /** Sets the MCP transport provider (required). */ + public AdcpServerBuilder transport(McpServerTransportProvider transport) { + this.transport = Objects.requireNonNull(transport); + return this; + } + + /** Overrides the Jackson ObjectMapper. */ + public AdcpServerBuilder objectMapper(ObjectMapper objectMapper) { + this.objectMapper = Objects.requireNonNull(objectMapper); + return this; + } + + /** Sets the advertised server name. */ + public AdcpServerBuilder serverName(String serverName) { + this.serverName = Objects.requireNonNull(serverName); + return this; + } + + /** Sets the advertised server version. */ + public AdcpServerBuilder serverVersion(String serverVersion) { + this.serverVersion = Objects.requireNonNull(serverVersion); + return this; + } + + /** Sets the AdCP version for response envelopes. */ + public AdcpServerBuilder adcpVersion(AdcpVersion adcpVersion) { + this.adcpVersion = adcpVersion; + return this; + } + + /** + * Builds and returns the MCP server. Call {@code initialize()} on the + * result to start accepting connections. + */ + public McpSyncServer build() { + if (transport == null) { + throw new org.adcontextprotocol.adcp.error.ConfigurationError( + "McpServerTransportProvider is required", "transport"); + } + + ObjectMapper om = objectMapper != null + ? objectMapper + : AdcpObjectMapperFactory.create(); + + Set tools = platform.supportedTools(); + Map descriptions = platform.toolDescriptions(); + Map schemas = platform.toolSchemas(); + log.info("Building AdCP server with {} tool(s): {}", tools.size(), tools); + + // Build the MCP server with tool handlers + var spec = McpServer.sync(transport) + .serverInfo(serverName, serverVersion); + + // Permissive open-object schema used when the platform doesn't + // provide a typed schema for a tool. + McpSchema.JsonSchema defaultSchema = new McpSchema.JsonSchema( + "object", Map.of(), List.of(), true, null, null); + + for (String toolName : tools) { + String description = descriptions.getOrDefault(toolName, toolName); + McpSchema.JsonSchema inputSchema = schemas.getOrDefault(toolName, defaultSchema); + McpSchema.Tool tool = McpSchema.Tool.builder() + .name(toolName) + .description(description) + .inputSchema(inputSchema) + .build(); + spec.toolCall(tool, + (exchange, request) -> handleToolCall(om, toolName, request)); + } + + return spec.build(); + } + + /** + * Dispatches a tool call through the platform, handling version extraction, + * envelope stripping, error wrapping, and response serialization. + * + *

Package-private for testing ({@code AdcpServerBuilderTest}). + */ + @SuppressWarnings("unchecked") + McpSchema.CallToolResult handleToolCall( + ObjectMapper om, String toolName, McpSchema.CallToolRequest request) { + try { + Map args = request.arguments() != null + ? new java.util.LinkedHashMap<>(request.arguments()) + : new java.util.LinkedHashMap<>(); + + AdcpVersion version = extractVersion(args); + + // Strip version envelope fields before passing to platform + args.remove("adcp_major_version"); + args.remove("adcp_version"); + + AdcpContext ctx = new AdcpContext(version, Map.of(), null); + + Object response = platform.handleTool(toolName, args, ctx); + + String json = om.writeValueAsString(response); + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent(json)), + false, null, Map.of()); + } catch (org.adcontextprotocol.adcp.error.AdcpError e) { + // Known application errors — surface the stable code plus a + // brief, sanitized message. The full message is logged server-side. + log.warn("Tool call failed ({}) [{}]: {}", toolName, e.code(), e.getMessage()); + String safeError; + try { + safeError = om.writeValueAsString( + Map.of("error", e.code(), + "message", sanitizeErrorMessage(e.getMessage()))); + } catch (Exception ignored) { + // e.code() is always an enum-like constant, but use a + // fixed string to be absolutely safe against JSON injection. + safeError = "{\"error\":\"internal_error\"}"; + } + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent(safeError)), + true, null, Map.of()); + } catch (Exception e) { + // Unknown errors — do NOT leak internal details to remote callers + log.error("Tool call failed: {}", toolName, e); + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("{\"error\":\"internal error\"}")), + true, null, Map.of()); + } + } + + private static final int MAX_ERROR_MESSAGE_LENGTH = 500; + + /** + * Sanitizes error messages before sending to remote callers. + * Prevents leaking internal details (stack traces, SQL, file paths). + */ + private static String sanitizeErrorMessage(String raw) { + if (raw == null) return "(no error detail)"; + String truncated = raw.length() > MAX_ERROR_MESSAGE_LENGTH + ? raw.substring(0, MAX_ERROR_MESSAGE_LENGTH) + "..." + : raw; + // Strip control characters except tab and newline + return truncated.replaceAll("[\\p{Cc}&&[^\t\n]]", ""); + } + + /** Package-private for testing. */ + @Nullable AdcpVersion extractVersion(Map args) { + Object majorRaw = args.get("adcp_major_version"); + int major; + if (majorRaw instanceof Number num) { + major = num.intValue(); + } else if (majorRaw instanceof String s) { + try { + major = Integer.parseInt(s); + } catch (NumberFormatException e) { + return adcpVersion; + } + } else { + return adcpVersion; + } + if (major < 1 || major > 99) { + throw new org.adcontextprotocol.adcp.error.VersionUnsupportedError( + null, "Unsupported AdCP major version: " + major, + String.valueOf(major), null); + } + // AdCP back-compat: versions < 3 default to v1 semantics rather + // than refusing the request (matches Python adcp.server behavior). + if (major < 3) { + log.debug("Client sent adcp_major_version={}, defaulting to v1 semantics", major); + return new AdcpVersion(major, null); + } + String minor = args.get("adcp_version") instanceof String s ? s : null; + // Guard against unbounded strings from untrusted input + if (minor != null && minor.length() > 20) { + log.warn("Rejecting oversized adcp_version field ({} chars)", minor.length()); + minor = null; + } + return new AdcpVersion(major, minor); + } +} diff --git a/adcp-server/src/test/java/org/adcontextprotocol/adcp/server/AdcpPlatformTest.java b/adcp-server/src/test/java/org/adcontextprotocol/adcp/server/AdcpPlatformTest.java new file mode 100644 index 0000000..c35478b --- /dev/null +++ b/adcp-server/src/test/java/org/adcontextprotocol/adcp/server/AdcpPlatformTest.java @@ -0,0 +1,102 @@ +package org.adcontextprotocol.adcp.server; + +import org.adcontextprotocol.adcp.AdcpVersion; +import org.adcontextprotocol.adcp.error.UnsupportedTaskError; +import org.junit.jupiter.api.Test; + +import java.util.Map; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Tests for {@link AdcpPlatform} and {@link AdcpContext}. + */ +class AdcpPlatformTest { + + @Test + void default_handleTool_throws_unsupported() { + AdcpPlatform platform = new AdcpPlatform() { + @Override + public Set supportedTools() { + return Set.of(); + } + }; + + AdcpContext ctx = new AdcpContext(AdcpVersion.V3, Map.of(), null); + + assertThrows(UnsupportedTaskError.class, + () -> platform.handleTool("get_products", Map.of(), ctx)); + } + + @Test + void custom_platform_handles_tool() { + AdcpPlatform platform = new AdcpPlatform() { + @Override + public Set supportedTools() { + return Set.of("get_products"); + } + + @Override + public Object handleTool(String toolName, Map request, AdcpContext ctx) { + if ("get_products".equals(toolName)) { + return Map.of("products", java.util.List.of()); + } + return super.handleTool(toolName, request, ctx); + } + }; + + AdcpContext ctx = new AdcpContext(AdcpVersion.V3, Map.of(), "req-1"); + Object result = platform.handleTool("get_products", Map.of(), ctx); + + assertNotNull(result); + assertInstanceOf(Map.class, result); + } + + @Test + void context_headers_are_immutable() { + var headers = new java.util.HashMap(); + headers.put("Authorization", "Bearer tok"); + + AdcpContext ctx = new AdcpContext(AdcpVersion.V3, headers, null); + + assertThrows(UnsupportedOperationException.class, + () -> ctx.headers().put("X-Evil", "val")); + } + + @Test + void context_records_request_id() { + AdcpContext ctx = new AdcpContext(null, Map.of(), "req-123"); + assertEquals("req-123", ctx.requestId()); + } + + @Test + void default_toolDescriptions_returns_empty() { + AdcpPlatform platform = new AdcpPlatform() {}; + assertTrue(platform.toolDescriptions().isEmpty()); + } + + @Test + void default_toolSchemas_returns_empty() { + AdcpPlatform platform = new AdcpPlatform() {}; + assertTrue(platform.toolSchemas().isEmpty()); + } + + @Test + void custom_toolDescriptions() { + AdcpPlatform platform = new AdcpPlatform() { + @Override + public Set supportedTools() { + return Set.of("get_products"); + } + + @Override + public Map toolDescriptions() { + return Map.of("get_products", "Retrieves product catalog for an advertiser"); + } + }; + + assertEquals("Retrieves product catalog for an advertiser", + platform.toolDescriptions().get("get_products")); + } +} diff --git a/adcp-server/src/test/java/org/adcontextprotocol/adcp/server/AdcpServerBuilderTest.java b/adcp-server/src/test/java/org/adcontextprotocol/adcp/server/AdcpServerBuilderTest.java new file mode 100644 index 0000000..bbdc6de --- /dev/null +++ b/adcp-server/src/test/java/org/adcontextprotocol/adcp/server/AdcpServerBuilderTest.java @@ -0,0 +1,255 @@ +package org.adcontextprotocol.adcp.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpSchema; +import org.adcontextprotocol.adcp.AdcpVersion; +import org.adcontextprotocol.adcp.schema.AdcpObjectMapperFactory; +import org.junit.jupiter.api.Test; + +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Tests for {@link AdcpServerBuilder}: handleToolCall dispatch, version + * extraction/stripping, error wrapping, and back-compat behavior. + * + *

These tests exercise the builder's internal wiring directly — the + * same code path that MCP tool calls follow at runtime. + */ +class AdcpServerBuilderTest { + + private final ObjectMapper om = AdcpObjectMapperFactory.create(); + + static class EchoPlatform extends AdcpPlatform { + Map lastArgs; + AdcpContext lastContext; + + @Override + public Set supportedTools() { + return Set.of("echo"); + } + + @Override + public Object handleTool(String toolName, Map request, AdcpContext ctx) { + lastArgs = request; + lastContext = ctx; + return Map.of("echo", request, "tool", toolName); + } + } + + private AdcpServerBuilder builderWith(AdcpPlatform platform) { + return AdcpServerBuilder.create(platform).adcpVersion(AdcpVersion.V3); + } + + // -- handleToolCall -- + + @Test + void handleToolCall_dispatches_and_returns_json_result() { + EchoPlatform platform = new EchoPlatform(); + AdcpServerBuilder builder = builderWith(platform); + + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest( + "echo", Map.of("query", "test")); + + McpSchema.CallToolResult result = builder.handleToolCall(om, "echo", request); + + assertFalse(result.isError()); + assertNotNull(result.content()); + assertFalse(result.content().isEmpty()); + assertInstanceOf(McpSchema.TextContent.class, result.content().getFirst()); + String json = ((McpSchema.TextContent) result.content().getFirst()).text(); + assertTrue(json.contains("\"echo\""), "Result should contain echo field: " + json); + assertTrue(json.contains("\"query\""), "Result should contain query arg: " + json); + } + + @Test + void handleToolCall_strips_version_envelope_before_dispatch() { + EchoPlatform platform = new EchoPlatform(); + AdcpServerBuilder builder = builderWith(platform); + + Map args = new LinkedHashMap<>(); + args.put("adcp_major_version", 3); + args.put("adcp_version", "3.1"); + args.put("query", "test"); + + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest("echo", args); + builder.handleToolCall(om, "echo", request); + + // Version fields should be stripped before reaching the platform + assertNotNull(platform.lastArgs); + assertFalse(platform.lastArgs.containsKey("adcp_major_version"), + "adcp_major_version should be stripped"); + assertFalse(platform.lastArgs.containsKey("adcp_version"), + "adcp_version should be stripped"); + assertEquals("test", platform.lastArgs.get("query"), + "Non-envelope args should be preserved"); + } + + @Test + void handleToolCall_extracts_version_into_context() { + EchoPlatform platform = new EchoPlatform(); + AdcpServerBuilder builder = builderWith(platform); + + Map args = new LinkedHashMap<>(); + args.put("adcp_major_version", 3); + args.put("adcp_version", "3.1"); + + builder.handleToolCall(om, "echo", + new McpSchema.CallToolRequest("echo", args)); + + assertNotNull(platform.lastContext); + assertNotNull(platform.lastContext.adcpVersion()); + assertEquals(3, platform.lastContext.adcpVersion().majorVersion()); + assertEquals("3.1", platform.lastContext.adcpVersion().minorVersion()); + } + + @Test + void handleToolCall_wraps_adcp_errors() { + AdcpPlatform failingPlatform = new AdcpPlatform() { + @Override + public Set supportedTools() { + return Set.of("fail"); + } + + @Override + public Object handleTool(String toolName, Map request, + AdcpContext ctx) { + throw new org.adcontextprotocol.adcp.error.UnsupportedTaskError("fail"); + } + }; + + AdcpServerBuilder builder = builderWith(failingPlatform); + McpSchema.CallToolResult result = builder.handleToolCall(om, "fail", + new McpSchema.CallToolRequest("fail", Map.of())); + + assertTrue(result.isError(), "Should be marked as error"); + String errorJson = ((McpSchema.TextContent) result.content().getFirst()).text(); + assertTrue(errorJson.contains("UNSUPPORTED_TASK"), + "Error should contain stable error code: " + errorJson); + } + + @Test + void handleToolCall_wraps_unexpected_exceptions() { + AdcpPlatform throwingPlatform = new AdcpPlatform() { + @Override + public Set supportedTools() { + return Set.of("boom"); + } + + @Override + public Object handleTool(String toolName, Map request, + AdcpContext ctx) { + throw new RuntimeException("Unexpected error with sensitive details"); + } + }; + + AdcpServerBuilder builder = builderWith(throwingPlatform); + McpSchema.CallToolResult result = builder.handleToolCall(om, "boom", + new McpSchema.CallToolRequest("boom", Map.of())); + + assertTrue(result.isError()); + String errorJson = ((McpSchema.TextContent) result.content().getFirst()).text(); + assertTrue(errorJson.contains("internal error"), + "Unknown errors should be wrapped as internal error: " + errorJson); + assertFalse(errorJson.contains("sensitive"), + "Internal details must not leak: " + errorJson); + } + + @Test + void handleToolCall_handles_null_arguments() { + EchoPlatform platform = new EchoPlatform(); + AdcpServerBuilder builder = builderWith(platform); + + McpSchema.CallToolResult result = builder.handleToolCall(om, "echo", + new McpSchema.CallToolRequest("echo", null)); + + assertFalse(result.isError()); + assertNotNull(platform.lastArgs); + assertTrue(platform.lastArgs.isEmpty()); + } + + // -- extractVersion -- + + @Test + void extractVersion_parses_integer_major() { + AdcpServerBuilder builder = builderWith(new EchoPlatform()); + Map args = new LinkedHashMap<>( + Map.of("adcp_major_version", 3, "adcp_version", "3.1")); + + AdcpVersion version = builder.extractVersion(args); + + assertNotNull(version); + assertEquals(3, version.majorVersion()); + assertEquals("3.1", version.minorVersion()); + } + + @Test + void extractVersion_parses_string_major() { + AdcpServerBuilder builder = builderWith(new EchoPlatform()); + Map args = new LinkedHashMap<>( + Map.of("adcp_major_version", "3")); + + AdcpVersion version = builder.extractVersion(args); + + assertNotNull(version); + assertEquals(3, version.majorVersion()); + } + + @Test + void extractVersion_defaults_to_v1_semantics_for_major_lt_3() { + AdcpServerBuilder builder = builderWith(new EchoPlatform()); + Map args = new LinkedHashMap<>( + Map.of("adcp_major_version", 1)); + + AdcpVersion version = builder.extractVersion(args); + + assertNotNull(version); + assertEquals(1, version.majorVersion()); + } + + @Test + void extractVersion_rejects_major_0() { + AdcpServerBuilder builder = builderWith(new EchoPlatform()); + Map args = new LinkedHashMap<>( + Map.of("adcp_major_version", 0)); + + assertThrows(org.adcontextprotocol.adcp.error.VersionUnsupportedError.class, + () -> builder.extractVersion(args)); + } + + @Test + void extractVersion_rejects_major_100() { + AdcpServerBuilder builder = builderWith(new EchoPlatform()); + Map args = new LinkedHashMap<>( + Map.of("adcp_major_version", 100)); + + assertThrows(org.adcontextprotocol.adcp.error.VersionUnsupportedError.class, + () -> builder.extractVersion(args)); + } + + @Test + void extractVersion_returns_default_when_no_version_field() { + AdcpServerBuilder builder = builderWith(new EchoPlatform()); + + AdcpVersion version = builder.extractVersion(Map.of("query", "test")); + + assertEquals(AdcpVersion.V3, version); + } + + @Test + void extractVersion_rejects_oversized_minor_version() { + AdcpServerBuilder builder = builderWith(new EchoPlatform()); + Map args = new LinkedHashMap<>( + Map.of("adcp_major_version", 3, "adcp_version", "x".repeat(100))); + + AdcpVersion version = builder.extractVersion(args); + + assertNotNull(version); + assertEquals(3, version.majorVersion()); + assertNull(version.minorVersion(), "Oversized minor should be rejected"); + } +} diff --git a/adcp-testing/build.gradle.kts b/adcp-testing/build.gradle.kts index 678f1d3..fdcc9cd 100644 --- a/adcp-testing/build.gradle.kts +++ b/adcp-testing/build.gradle.kts @@ -16,4 +16,7 @@ dependencies { // JUnit Jupiter is part of the public surface — adopters write tests against // AdcpAgentExtension on the api scope. api(libs.junit.jupiter.api) + + // Server module is needed for integration tests (AdcpPlatform + AdcpServerBuilder) + testImplementation(project(":adcp-server")) } diff --git a/adcp-testing/gradle.lockfile b/adcp-testing/gradle.lockfile index fcb2d5b..b8b4b03 100644 --- a/adcp-testing/gradle.lockfile +++ b/adcp-testing/gradle.lockfile @@ -1,14 +1,24 @@ # This is a Gradle generated file for dependency locking. # Manual edits can break the build and are not advised. # This file is expected to be part of source control. -com.ethlo.time:itu:1.10.3=runtimeClasspath,testRuntimeClasspath -com.fasterxml.jackson.core:jackson-annotations:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.fasterxml.jackson.core:jackson-core:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.fasterxml.jackson.core:jackson-databind:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:2.18.2=runtimeClasspath,testRuntimeClasspath -com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.fasterxml.jackson:jackson-bom:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.networknt:json-schema-validator:1.5.6=runtimeClasspath,testRuntimeClasspath +com.ethlo.time:itu:1.10.3=runtimeClasspath +com.ethlo.time:itu:1.14.0=testCompileClasspath,testRuntimeClasspath +com.fasterxml.jackson.core:jackson-annotations:2.18.2=compileClasspath +com.fasterxml.jackson.core:jackson-annotations:2.20=runtimeClasspath,testCompileClasspath,testRuntimeClasspath +com.fasterxml.jackson.core:jackson-core:2.18.2=compileClasspath +com.fasterxml.jackson.core:jackson-core:2.20.1=runtimeClasspath,testCompileClasspath,testRuntimeClasspath +com.fasterxml.jackson.core:jackson-databind:2.18.2=compileClasspath +com.fasterxml.jackson.core:jackson-databind:2.20.1=runtimeClasspath,testCompileClasspath,testRuntimeClasspath +com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:2.20.1=runtimeClasspath,testCompileClasspath,testRuntimeClasspath +com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.18.2=compileClasspath +com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.20.1=runtimeClasspath,testCompileClasspath,testRuntimeClasspath +com.fasterxml.jackson:jackson-bom:2.18.2=compileClasspath +com.fasterxml.jackson:jackson-bom:2.20.1=runtimeClasspath,testCompileClasspath,testRuntimeClasspath +com.networknt:json-schema-validator:1.5.6=runtimeClasspath +com.networknt:json-schema-validator:2.0.0=testCompileClasspath,testRuntimeClasspath +io.modelcontextprotocol.sdk:mcp-core:1.1.2=runtimeClasspath,testCompileClasspath,testRuntimeClasspath +io.modelcontextprotocol.sdk:mcp-json-jackson2:1.1.2=runtimeClasspath,testCompileClasspath,testRuntimeClasspath +io.projectreactor:reactor-core:3.7.0=runtimeClasspath,testCompileClasspath,testRuntimeClasspath org.apiguardian:apiguardian-api:1.1.2=compileClasspath,testCompileClasspath org.jspecify:jspecify:1.0.0=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath org.junit.jupiter:junit-jupiter-api:5.11.4=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath @@ -19,6 +29,8 @@ org.junit.platform:junit-platform-engine:1.11.4=testRuntimeClasspath org.junit.platform:junit-platform-launcher:1.11.4=testRuntimeClasspath org.junit:junit-bom:5.11.4=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath org.opentest4j:opentest4j:1.3.0=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -org.slf4j:slf4j-api:2.0.16=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -org.yaml:snakeyaml:2.3=runtimeClasspath,testRuntimeClasspath +org.reactivestreams:reactive-streams:1.0.4=runtimeClasspath,testCompileClasspath,testRuntimeClasspath +org.slf4j:slf4j-api:2.0.16=compileClasspath,runtimeClasspath +org.slf4j:slf4j-api:2.0.17=testCompileClasspath,testRuntimeClasspath +org.yaml:snakeyaml:2.4=runtimeClasspath,testCompileClasspath,testRuntimeClasspath empty=annotationProcessor,testAnnotationProcessor diff --git a/adcp-testing/src/test/java/org/adcontextprotocol/adcp/testing/AdcpClientIntegrationTest.java b/adcp-testing/src/test/java/org/adcontextprotocol/adcp/testing/AdcpClientIntegrationTest.java new file mode 100644 index 0000000..1c198b3 --- /dev/null +++ b/adcp-testing/src/test/java/org/adcontextprotocol/adcp/testing/AdcpClientIntegrationTest.java @@ -0,0 +1,95 @@ +package org.adcontextprotocol.adcp.testing; + +import org.adcontextprotocol.adcp.AdcpClient; +import org.adcontextprotocol.adcp.AdcpVersion; +import org.adcontextprotocol.adcp.AgentConfig; +import org.adcontextprotocol.adcp.Protocol; +import org.adcontextprotocol.adcp.http.SsrfPolicy; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import java.net.URI; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Integration test that uses {@link AdcpClient} to call the + * {@code @adcp/sdk/mock-server} sidecar when available. + * + *

Skipped unless {@code ADCP_MOCK_SERVER_URL} is set. CI provides it + * via the storyboard workflow. + * + *

Validates the full caller-side stack: + * {@code AdcpClient} → {@code ProtocolClient} → {@code McpCaller} → + * MCP transport → mock-server. + */ +@EnabledIfEnvironmentVariable( + named = "ADCP_MOCK_SERVER_URL", + matches = ".+", + disabledReason = "Set ADCP_MOCK_SERVER_URL to run; CI sets it automatically" +) +class AdcpClientIntegrationTest { + + private static URI mockServerUri() { + return URI.create(System.getenv("ADCP_MOCK_SERVER_URL")); + } + + @Test + void client_builder_configures_against_mock_server() { + AgentConfig agent = AgentConfig.mcp("mock", mockServerUri()); + + try (AdcpClient client = AdcpClient.builder() + .agent(agent) + .adcpVersion(AdcpVersion.V3) + .ssrfPolicy(SsrfPolicy.permissive()) + .build()) { + assertNotNull(client); + assertEquals(Protocol.MCP, client.agent().protocol()); + assertEquals(mockServerUri(), client.agent().agentUri()); + } + } + + /** + * Exercises the full caller stack against a live MCP-speaking server. + * + *

The current {@code @adcp/sdk} mock-server is a REST stub, not an + * MCP server, so this test is guarded behind a separate env var + * ({@code ADCP_MCP_SERVER_URL}) until the mock-server gains MCP + * support. + */ + @Test + @EnabledIfEnvironmentVariable( + named = "ADCP_MCP_SERVER_URL", + matches = ".+", + disabledReason = "Set ADCP_MCP_SERVER_URL to run against an MCP-speaking server" + ) + @SuppressWarnings("unchecked") + void callTool_get_adcp_capabilities_returns_response() { + URI mcpUri = URI.create(System.getenv("ADCP_MCP_SERVER_URL")); + AgentConfig agent = AgentConfig.mcp("mock", mcpUri); + + try (AdcpClient client = AdcpClient.builder() + .agent(agent) + .adcpVersion(AdcpVersion.V3) + .ssrfPolicy(SsrfPolicy.permissive()) + .build()) { + Map result = client.callTool( + "get_adcp_capabilities", Map.of(), Map.class); + + // Validate spec shape — not just non-null + assertNotNull(result, "get_adcp_capabilities should return a response"); + assertFalse(result.isEmpty(), + "Response should contain at least one field"); + + // The response should carry either 'capabilities' or version fields + // depending on the mock-server implementation + boolean hasCapabilities = result.containsKey("capabilities"); + boolean hasVersion = result.containsKey("adcp_version") + || result.containsKey("adcp_major_version"); + assertTrue(hasCapabilities || hasVersion, + "Response should contain 'capabilities' or version fields, got: " + + result.keySet()); + } + } +} diff --git a/adcp-testing/src/test/java/org/adcontextprotocol/adcp/testing/ServerBuilderRoundTripTest.java b/adcp-testing/src/test/java/org/adcontextprotocol/adcp/testing/ServerBuilderRoundTripTest.java new file mode 100644 index 0000000..f2ec82f --- /dev/null +++ b/adcp-testing/src/test/java/org/adcontextprotocol/adcp/testing/ServerBuilderRoundTripTest.java @@ -0,0 +1,219 @@ +package org.adcontextprotocol.adcp.testing; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.adcontextprotocol.adcp.AdcpVersion; +import org.adcontextprotocol.adcp.error.UnsupportedTaskError; +import org.adcontextprotocol.adcp.schema.AdcpObjectMapperFactory; +import org.adcontextprotocol.adcp.server.AdcpContext; +import org.adcontextprotocol.adcp.server.AdcpPlatform; +import org.adcontextprotocol.adcp.server.AdcpServerBuilder; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Integration test that verifies the full server-side wiring: + * {@link AdcpPlatform} → {@link AdcpServerBuilder} → MCP server. + * + *

This test validates that the SDK correctly: + *

    + *
  • Introspects supported tools from the platform
  • + *
  • Builds an MCP server with the correct tool registrations
  • + *
  • Dispatches tool calls through the builder's handleToolCall path
  • + *
  • Handles errors correctly
  • + *
+ */ +class ServerBuilderRoundTripTest { + + /** + * A simple test platform that supports get_products and list_accounts. + */ + static class TestPlatform extends AdcpPlatform { + boolean getProductsCalled; + boolean listAccountsCalled; + AdcpContext lastContext; + + @Override + public Set supportedTools() { + return Set.of("get_products", "list_accounts"); + } + + @Override + public Object handleTool(String toolName, Map request, AdcpContext ctx) { + lastContext = ctx; + return switch (toolName) { + case "get_products" -> { + getProductsCalled = true; + yield Map.of("products", List.of( + Map.of("id", "p1", "name", "Product 1"), + Map.of("id", "p2", "name", "Product 2"))); + } + case "list_accounts" -> { + listAccountsCalled = true; + yield Map.of("accounts", List.of()); + } + default -> super.handleTool(toolName, request, ctx); + }; + } + } + + @Test + void builder_creates_server_with_correct_tool_count() { + TestPlatform platform = new TestPlatform(); + + assertEquals(2, platform.supportedTools().size()); + assertTrue(platform.supportedTools().contains("get_products")); + assertTrue(platform.supportedTools().contains("list_accounts")); + } + + @Test + void builder_build_creates_mcp_server() { + TestPlatform platform = new TestPlatform(); + // Use a StubTransport so we exercise the builder.build() path + McpServerTransportProvider transport = new StubMcpTransport(); + + McpSyncServer server = AdcpServerBuilder.create(platform) + .transport(transport) + .serverName("test-server") + .serverVersion("0.0.1") + .adcpVersion(AdcpVersion.V3) + .build(); + + assertNotNull(server, "build() should return a non-null MCP server"); + } + + @Test + void platform_dispatches_get_products() { + TestPlatform platform = new TestPlatform(); + AdcpContext ctx = new AdcpContext(AdcpVersion.V3, Map.of(), "req-1"); + + @SuppressWarnings("unchecked") + Map result = (Map) + platform.handleTool("get_products", Map.of(), ctx); + + assertTrue(platform.getProductsCalled); + assertNotNull(result.get("products")); + assertInstanceOf(List.class, result.get("products")); + } + + @Test + void platform_dispatches_list_accounts() { + TestPlatform platform = new TestPlatform(); + AdcpContext ctx = new AdcpContext(AdcpVersion.V3, Map.of(), "req-2"); + + @SuppressWarnings("unchecked") + Map result = (Map) + platform.handleTool("list_accounts", Map.of(), ctx); + + assertTrue(platform.listAccountsCalled); + assertNotNull(result.get("accounts")); + } + + @Test + void platform_rejects_unsupported_tool() { + TestPlatform platform = new TestPlatform(); + AdcpContext ctx = new AdcpContext(AdcpVersion.V3, Map.of(), "req-3"); + + UnsupportedTaskError error = assertThrows(UnsupportedTaskError.class, + () -> platform.handleTool("sync_creatives", Map.of(), ctx)); + + assertTrue(error.getMessage().contains("sync_creatives")); + } + + @Test + void platform_receives_context_with_version_and_headers() { + TestPlatform platform = new TestPlatform(); + Map headers = Map.of( + "Authorization", "Bearer test-token", + "X-Request-Id", "req-456"); + + AdcpContext ctx = new AdcpContext(AdcpVersion.V3_1, headers, "req-456"); + platform.handleTool("get_products", Map.of(), ctx); + + assertNotNull(platform.lastContext); + assertEquals(AdcpVersion.V3_1, platform.lastContext.adcpVersion()); + assertEquals("Bearer test-token", platform.lastContext.headers().get("Authorization")); + assertEquals("req-456", platform.lastContext.requestId()); + } + + @Test + void server_builder_requires_transport() { + TestPlatform platform = new TestPlatform(); + + assertThrows(Exception.class, () -> + AdcpServerBuilder.create(platform).build()); + } + + @Test + void server_builder_accepts_custom_server_info() { + TestPlatform platform = new TestPlatform(); + + AdcpServerBuilder builder = AdcpServerBuilder.create(platform) + .serverName("test-agent") + .serverVersion("1.0.0") + .adcpVersion(AdcpVersion.V3); + + assertNotNull(builder); + } + + @Test + void version_extraction_strips_envelope_from_args() { + TestPlatform platform = new TestPlatform(); + Map argsWithVersion = new java.util.LinkedHashMap<>(Map.of( + "adcp_major_version", 3, + "adcp_version", "3.1", + "query", "test")); + + AdcpContext ctx = new AdcpContext(AdcpVersion.V3_1, Map.of(), null); + Object result = platform.handleTool("get_products", argsWithVersion, ctx); + assertNotNull(result); + } + + @Test + void multiple_tools_independent_dispatch() { + TestPlatform platform = new TestPlatform(); + AdcpContext ctx = new AdcpContext(AdcpVersion.V3, Map.of(), null); + + assertFalse(platform.getProductsCalled); + assertFalse(platform.listAccountsCalled); + + platform.handleTool("get_products", Map.of(), ctx); + assertTrue(platform.getProductsCalled); + assertFalse(platform.listAccountsCalled); + + platform.handleTool("list_accounts", Map.of(), ctx); + assertTrue(platform.listAccountsCalled); + } + + /** + * Minimal stub transport that satisfies the non-null requirement for + * {@link AdcpServerBuilder#build()} without starting a real server. + * The MCP server builder calls no methods on the transport during build(). + */ + private static class StubMcpTransport implements McpServerTransportProvider { + @Override + public void setSessionFactory( + io.modelcontextprotocol.spec.McpServerSession.Factory factory) { + // no-op + } + + @Override + public reactor.core.publisher.Mono notifyClients( + String method, + Object params) { + return reactor.core.publisher.Mono.empty(); + } + + @Override + public reactor.core.publisher.Mono closeGracefully() { + return reactor.core.publisher.Mono.empty(); + } + } +} diff --git a/adcp/build.gradle.kts b/adcp/build.gradle.kts index 8a1ae54..6550046 100644 --- a/adcp/build.gradle.kts +++ b/adcp/build.gradle.kts @@ -14,4 +14,72 @@ dependencies { api(libs.slf4j.api) api(libs.jspecify) implementation(libs.json.schema.validator) + // MCP SDK client transport — needed for McpClient, StreamableHTTP, SSE fallback. + // Same artifacts as adcp-server; here they provide the caller/client side. + // Exclude json-schema-validator from mcp-json-jackson2 to keep our pinned 1.5.x. + implementation(libs.mcp.core) + implementation(libs.mcp.json.jackson2) { + exclude(group = "com.networknt", module = "json-schema-validator") + } +} + +// -- Build-time SDK version constant ---------------------------------------- +// Reads ADCP_VERSION (e.g. "3.0.11") and generates AdcpSdkVersion.java with +// the major and release-precision (major.minor) constants. This lets callers +// do cross-major validation at config time without hardcoding a version number. +// Output lands in build/generated/ and is NOT checked in. + +val generateSdkVersion = tasks.register("generateSdkVersion") { + val versionFile = rootProject.file("ADCP_VERSION") + inputs.file(versionFile) + val outputDir = layout.buildDirectory.dir("generated/sources/sdk-version/main/java") + outputs.dir(outputDir) + + doLast { + val raw = versionFile.readText().trim() + val parts = raw.split(".") + require(parts.size >= 2) { "ADCP_VERSION must be in major.minor.patch format: $raw" } + val major = parts[0].toInt() + val release = "${parts[0]}.${parts[1]}" // release-precision, e.g. "3.0" + + val pkg = "org.adcontextprotocol.adcp" + val dir = outputDir.get().asFile.resolve(pkg.replace('.', '/')) + dir.mkdirs() + dir.resolve("AdcpSdkVersion.java").writeText( + """ + package $pkg; + + /** + * Build-time AdCP SDK version constants — generated from {@code ADCP_VERSION}. + * Do not edit manually; update {@code ADCP_VERSION} at the repo root instead. + * + *

Used for cross-major validation: if a caller pins + * {@code adcpVersion("X.Y")} and {@code X != SDK_MAJOR_VERSION}, + * a {@link org.adcontextprotocol.adcp.error.ConfigurationError} is thrown + * before any network request is made. + */ + public final class AdcpSdkVersion { + + private AdcpSdkVersion() {} + + /** Major protocol version this SDK was built for (e.g. {@code 3}). */ + public static final int SDK_MAJOR_VERSION = $major; + + /** + * Release-precision protocol version this SDK was built for + * (e.g. {@code "3.0"}). + */ + public static final String SDK_RELEASE_VERSION = "$release"; + } + """.trimIndent() + ) + } +} + +sourceSets.named("main") { + java.srcDir(generateSdkVersion.map { it.outputs.files.singleFile }) +} + +tasks.named("compileJava") { + dependsOn(generateSdkVersion) } diff --git a/adcp/gradle.lockfile b/adcp/gradle.lockfile index 42911fd..afd7482 100644 --- a/adcp/gradle.lockfile +++ b/adcp/gradle.lockfile @@ -2,13 +2,16 @@ # Manual edits can break the build and are not advised. # This file is expected to be part of source control. com.ethlo.time:itu:1.10.3=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.fasterxml.jackson.core:jackson-annotations:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.fasterxml.jackson.core:jackson-core:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.fasterxml.jackson.core:jackson-databind:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -com.fasterxml.jackson:jackson-bom:2.18.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath +com.fasterxml.jackson.core:jackson-annotations:2.20=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath +com.fasterxml.jackson.core:jackson-core:2.20.1=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath +com.fasterxml.jackson.core:jackson-databind:2.20.1=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath +com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:2.20.1=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath +com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.20.1=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath +com.fasterxml.jackson:jackson-bom:2.20.1=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath com.networknt:json-schema-validator:1.5.6=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath +io.modelcontextprotocol.sdk:mcp-core:1.1.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath +io.modelcontextprotocol.sdk:mcp-json-jackson2:1.1.2=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath +io.projectreactor:reactor-core:3.7.0=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath org.apiguardian:apiguardian-api:1.1.2=testCompileClasspath org.jspecify:jspecify:1.0.0=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath org.junit.jupiter:junit-jupiter-api:5.11.4=testCompileClasspath,testRuntimeClasspath @@ -19,6 +22,7 @@ org.junit.platform:junit-platform-engine:1.11.4=testRuntimeClasspath org.junit.platform:junit-platform-launcher:1.11.4=testRuntimeClasspath org.junit:junit-bom:5.11.4=testCompileClasspath,testRuntimeClasspath org.opentest4j:opentest4j:1.3.0=testCompileClasspath,testRuntimeClasspath +org.reactivestreams:reactive-streams:1.0.4=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath org.slf4j:slf4j-api:2.0.16=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath -org.yaml:snakeyaml:2.3=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath +org.yaml:snakeyaml:2.4=compileClasspath,runtimeClasspath,testCompileClasspath,testRuntimeClasspath empty=annotationProcessor,testAnnotationProcessor diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/AdcpClient.java b/adcp/src/main/java/org/adcontextprotocol/adcp/AdcpClient.java new file mode 100644 index 0000000..3a9cf19 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/AdcpClient.java @@ -0,0 +1,231 @@ +package org.adcontextprotocol.adcp; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.adcontextprotocol.adcp.error.ConfigurationError; +import org.adcontextprotocol.adcp.http.AdcpHttpClient; +import org.adcontextprotocol.adcp.http.SsrfPolicy; +import org.adcontextprotocol.adcp.schema.AdcpObjectMapperFactory; +import org.adcontextprotocol.adcp.transport.CallToolOptions; +import org.adcontextprotocol.adcp.transport.ProtocolClient; +import org.adcontextprotocol.adcp.transport.mcp.McpConnectionManager; +import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.Map; +import java.util.Objects; + +/** + * The main user-facing AdCP client. Single-agent, all tool methods + * funnel through {@link ProtocolClient#callTool}. + * + *

Usage: + *

{@code
+ * try (AdcpClient client = AdcpClient.builder()
+ *         .agent(AgentConfig.mcp("seller", URI.create("https://agent.example.com")))
+ *         .build()) {
+ *     var resp = client.callTool("get_products", args, GetProductsResponse.class);
+ * }
+ * }
+ * + *

Named convenience methods (e.g. {@code getProducts()}) are provided + * for each tool in the TS SDK parity target. All funnel through + * {@link #callTool(String, Map, Class, CallToolOptions)}. + */ +public final class AdcpClient implements AutoCloseable { + + private static final Logger log = LoggerFactory.getLogger(AdcpClient.class); + + private final AgentConfig agent; + private final ProtocolClient protocolClient; + private final AdcpHttpClient adcpHttpClient; + private final ObjectMapper objectMapper; + private final @Nullable AdcpVersion adcpVersion; + + private AdcpClient(Builder builder) { + if (builder.agent == null) { + throw new ConfigurationError("AdcpClient.agent is required", "agent"); + } + this.agent = builder.agent; + this.adcpVersion = builder.adcpVersion; + + this.objectMapper = builder.objectMapper != null + ? builder.objectMapper + : AdcpObjectMapperFactory.create(); + + SsrfPolicy ssrfPolicy = builder.ssrfPolicy != null + ? builder.ssrfPolicy + : SsrfPolicy.strict(); + + this.adcpHttpClient = AdcpHttpClient.builder() + .ssrfPolicy(ssrfPolicy) + .build(); + McpConnectionManager connectionManager = new McpConnectionManager( + Duration.ofSeconds(10), builder.requestTimeout, adcpHttpClient); + this.protocolClient = new ProtocolClient( + this.objectMapper, ssrfPolicy, adcpVersion, connectionManager); + } + + /** Creates a new builder. */ + public static Builder builder() { + return new Builder(); + } + + // -- Generic tool call -- + + /** + * Calls a tool with explicit options. + * + * @param toolName the MCP tool name (e.g. "get_products") + * @param args tool arguments + * @param args tool arguments (may be {@code null}, treated as empty) + * @param responseType expected response type + * @param options call options + * @param response type + * @return deserialized response + */ + public T callTool(String toolName, @Nullable Map args, + Class responseType, CallToolOptions options) { + return protocolClient.callTool(agent, toolName, + args != null ? args : Map.of(), responseType, options); + } + + /** + * Calls a tool with default options. + */ + public T callTool(String toolName, @Nullable Map args, + Class responseType) { + return callTool(toolName, args, responseType, CallToolOptions.DEFAULT); + } + + // -- Named convenience methods (TS SDK parity) -- + // Each converts a typed request to a Map and delegates to callTool. + + /** + * Converts a request object to a Map for the tool call. + * Uses the ObjectMapper to handle the conversion. + */ + @SuppressWarnings("unchecked") + private Map toArgs(Object request) { + return objectMapper.convertValue(request, Map.class); + } + + /** + * Calls a named tool with a typed request object. + * + * @param toolName the MCP tool name + * @param request the typed request object (will be serialized to Map) + * @param responseType expected response type + * @param response type + * @return deserialized response + */ + public T callNamedTool(String toolName, Object request, + Class responseType) { + return callTool(toolName, toArgs(request), responseType); + } + + // -- Lifecycle -- + + /** Returns the agent config this client is bound to. */ + public AgentConfig agent() { + return agent; + } + + /** Returns the protocol version in use. */ + public @Nullable AdcpVersion adcpVersion() { + return adcpVersion; + } + + @Override + public void close() { + try { + protocolClient.close(); + } finally { + adcpHttpClient.close(); + } + } + + // -- Builder -- + + public static final class Builder { + private @Nullable AgentConfig agent; + private @Nullable AdcpVersion adcpVersion; + private @Nullable ObjectMapper objectMapper; + private @Nullable SsrfPolicy ssrfPolicy; + private Duration requestTimeout = Duration.ofSeconds(30); + + private Builder() {} + + /** Required: the agent to connect to. */ + public Builder agent(AgentConfig agent) { + this.agent = Objects.requireNonNull(agent); + return this; + } + + /** Pin a specific AdCP protocol version. */ + public Builder adcpVersion(AdcpVersion adcpVersion) { + this.adcpVersion = adcpVersion; + return this; + } + + /** + * Pin a specific AdCP protocol version by release-precision string + * (e.g. {@code "3.0"}, {@code "3.1"}). + * + *

Equivalent to {@code adcpVersion(AdcpVersion.of(releaseVersion))}. + * Throws {@link org.adcontextprotocol.adcp.error.ConfigurationError} at + * {@link #build()} time if the major version does not match the SDK. + */ + public Builder adcpVersion(String releaseVersion) { + return adcpVersion(AdcpVersion.of(releaseVersion)); + } + + /** Override the Jackson ObjectMapper. */ + public Builder objectMapper(ObjectMapper objectMapper) { + this.objectMapper = Objects.requireNonNull(objectMapper); + return this; + } + + /** + * Override the SSRF policy. Defaults to {@link SsrfPolicy#strict()}. + * Use {@link SsrfPolicy#permissive()} for local development only. + */ + public Builder ssrfPolicy(SsrfPolicy ssrfPolicy) { + this.ssrfPolicy = Objects.requireNonNull(ssrfPolicy); + return this; + } + + /** + * Override the per-request timeout for MCP tool calls. Defaults to 30 seconds. + * Increase this for agents that perform long-running operations. + */ + public Builder requestTimeout(Duration requestTimeout) { + this.requestTimeout = Objects.requireNonNull(requestTimeout, "requestTimeout"); + return this; + } + + /** Builds the client. */ + public AdcpClient build() { + validateAdcpVersion(adcpVersion); + return new AdcpClient(this); + } + + /** + * Validates that the pinned version's major matches the SDK's built-in major. + * Cross-major pins (e.g. requesting "2.0" from a major-3 SDK) fail fast before + * any network request. + */ + private static void validateAdcpVersion(@Nullable AdcpVersion version) { + if (version == null) return; + if (version.majorVersion() != AdcpSdkVersion.SDK_MAJOR_VERSION) { + throw new ConfigurationError( + "adcpVersion major " + version.majorVersion() + + " does not match SDK major " + + AdcpSdkVersion.SDK_MAJOR_VERSION + + " (built for AdCP " + AdcpSdkVersion.SDK_RELEASE_VERSION + ")", + "adcpVersion"); + } + } + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/AdcpVersion.java b/adcp/src/main/java/org/adcontextprotocol/adcp/AdcpVersion.java new file mode 100644 index 0000000..9ab8bf0 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/AdcpVersion.java @@ -0,0 +1,69 @@ +package org.adcontextprotocol.adcp; + +import org.jspecify.annotations.Nullable; + +/** + * AdCP protocol version identifier. + * + *

Used by the version envelope injected into every tool call. + * The {@link #majorVersion()} is always present; the {@link #minorVersion()} + * is set only when a specific minor version is required. + * + * @param majorVersion the major protocol version (e.g. 3) + * @param minorVersion optional minor version string (e.g. "3.1"), or {@code null} + */ +public record AdcpVersion(int majorVersion, @Nullable String minorVersion) { + + private static final java.util.regex.Pattern MINOR_VERSION_PATTERN = + java.util.regex.Pattern.compile("\\d+\\.\\d+(\\.\\d+)?"); + + /** AdCP v3.0 (current default). */ + public static final AdcpVersion V3 = new AdcpVersion(3, null); + + /** AdCP v3.1. */ + public static final AdcpVersion V3_1 = new AdcpVersion(3, "3.1"); + + public AdcpVersion { + if (majorVersion < 1) { + throw new IllegalArgumentException("majorVersion must be >= 1: " + majorVersion); + } + if (minorVersion != null) { + if (minorVersion.length() > 20) { + throw new IllegalArgumentException( + "minorVersion too long: " + minorVersion.length()); + } + if (!MINOR_VERSION_PATTERN.matcher(minorVersion).matches()) { + throw new IllegalArgumentException( + "minorVersion must be a version string (e.g. '3.1'): " + + minorVersion); + } + if (!minorVersion.startsWith(majorVersion + ".")) { + throw new IllegalArgumentException( + "minorVersion must start with majorVersion: " + minorVersion); + } + } + } + + /** + * Parses a release-precision version string (e.g. {@code "3.0"}, {@code "3.1"}) + * into an {@code AdcpVersion}. + * + *

This is the string-based convenience factory — pass the same value you + * would set in the Python SDK or TS SDK {@code adcpVersion} constructor option. + * + * @param releaseVersion release-precision version (major.minor, e.g. {@code "3.0"}) + * @return parsed {@code AdcpVersion} + * @throws IllegalArgumentException if the string is not in major.minor format + */ + public static AdcpVersion of(String releaseVersion) { + java.util.Objects.requireNonNull(releaseVersion, "releaseVersion"); + if (!MINOR_VERSION_PATTERN.matcher(releaseVersion).matches()) { + throw new IllegalArgumentException( + "releaseVersion must be in major.minor format (e.g. '3.0'): " + + releaseVersion); + } + int dotIndex = releaseVersion.indexOf('.'); + int major = Integer.parseInt(releaseVersion.substring(0, dotIndex)); + return new AdcpVersion(major, releaseVersion); + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/AgentConfig.java b/adcp/src/main/java/org/adcontextprotocol/adcp/AgentConfig.java new file mode 100644 index 0000000..0edd865 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/AgentConfig.java @@ -0,0 +1,264 @@ +package org.adcontextprotocol.adcp; + +import org.adcontextprotocol.adcp.auth.BasicCredentials; +import org.adcontextprotocol.adcp.auth.OAuthClientCredentials; +import org.adcontextprotocol.adcp.auth.OAuthTokens; +import org.adcontextprotocol.adcp.error.ConfigurationError; +import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +/** + * Configuration for connecting to an AdCP agent. + * + *

Carries the agent URI, transport protocol, auth credentials, + * and optional settings like request signing and webhook configuration. + * + *

Auth is mutually exclusive: + *

    + *
  • Static Bearer token ({@code authToken})
  • + *
  • HTTP Basic ({@code basicAuth})
  • + *
  • OAuth client-credentials ({@code oauthClientCredentials})
  • + *
  • OAuth auth-code ({@code oauthTokens})
  • + *
+ * + * Use {@link #builder()} to construct instances. + */ +public record AgentConfig( + String id, + URI agentUri, + Protocol protocol, + @Nullable String authToken, + @Nullable BasicCredentials basicAuth, + @Nullable OAuthClientCredentials oauthClientCredentials, + @Nullable OAuthTokens oauthTokens, + @Nullable String webhookUrlTemplate, + @Nullable String webhookSecret, + @Nullable AdcpVersion adcpVersion, + Map extraHeaders +) { + + private static final Logger log = LoggerFactory.getLogger(AgentConfig.class); + + public AgentConfig { + Objects.requireNonNull(id, "id"); + Objects.requireNonNull(agentUri, "agentUri"); + Objects.requireNonNull(protocol, "protocol"); + extraHeaders = Map.copyOf(extraHeaders); + validateAuth(authToken, basicAuth, oauthClientCredentials, oauthTokens); + validateAuthToken(authToken); + validateExtraHeaders(extraHeaders); + warnPlaintextAuth(agentUri, authToken, basicAuth, oauthClientCredentials, oauthTokens); + } + + @Override + public String toString() { + return "AgentConfig[id=" + id + + ", agentUri=" + agentUri + + ", protocol=" + protocol + + ", authToken=" + (authToken != null ? "" : "null") + + ", basicAuth=" + (basicAuth != null ? "" : "null") + + ", oauthClientCredentials=" + (oauthClientCredentials != null ? "" : "null") + + ", oauthTokens=" + (oauthTokens != null ? "" : "null") + + ", webhookUrlTemplate=" + webhookUrlTemplate + + ", webhookSecret=" + (webhookSecret != null ? "" : "null") + + ", adcpVersion=" + adcpVersion + + ", extraHeaders=" + (extraHeaders.isEmpty() + ? "{}" : "<" + extraHeaders.size() + " headers>") + "]"; + } + + /** Creates a builder for {@code AgentConfig}. */ + public static Builder builder() { + return new Builder(); + } + + /** Shorthand: creates a minimal MCP agent config with no auth. */ + public static AgentConfig mcp(String id, URI agentUri) { + return builder().id(id).agentUri(agentUri).protocol(Protocol.MCP).build(); + } + + /** Shorthand: creates an MCP agent config with a static Bearer token. */ + public static AgentConfig mcp(String id, URI agentUri, String authToken) { + return builder() + .id(id) + .agentUri(agentUri) + .protocol(Protocol.MCP) + .authToken(authToken) + .build(); + } + + private static void validateAuth( + @Nullable String authToken, + @Nullable BasicCredentials basicAuth, + @Nullable OAuthClientCredentials oauthCC, + @Nullable OAuthTokens oauthTokens) { + + int count = 0; + if (authToken != null) count++; + if (basicAuth != null) count++; + if (oauthCC != null) count++; + if (oauthTokens != null) count++; + + if (count > 1) { + throw new ConfigurationError( + "Only one auth mechanism may be set on an AgentConfig " + + "(got " + count + ": authToken=" + + (authToken != null) + ", basicAuth=" + + (basicAuth != null) + ", oauthClientCredentials=" + + (oauthCC != null) + ", oauthTokens=" + + (oauthTokens != null) + ")", + "auth"); + } + } + + private static void validateAuthToken(@Nullable String authToken) { + if (authToken != null + && (authToken.indexOf('\r') >= 0 || authToken.indexOf('\n') >= 0)) { + throw new ConfigurationError( + "authToken must not contain CR/LF characters", "authToken"); + } + } + + private static void validateExtraHeaders(Map headers) { + for (var entry : headers.entrySet()) { + if (hasCrlf(entry.getKey()) || hasCrlf(entry.getValue())) { + throw new ConfigurationError( + "extraHeaders key/value must not contain CR/LF: " + + entry.getKey(), "extraHeaders"); + } + } + } + + private static boolean hasCrlf(String s) { + return s.indexOf('\r') >= 0 || s.indexOf('\n') >= 0; + } + + private static void warnPlaintextAuth( + URI agentUri, + @Nullable String authToken, + @Nullable BasicCredentials basicAuth, + @Nullable OAuthClientCredentials oauthCC, + @Nullable OAuthTokens oauthTokens) { + boolean hasAuth = authToken != null || basicAuth != null + || oauthCC != null || oauthTokens != null; + if (hasAuth && "http".equalsIgnoreCase(agentUri.getScheme())) { + log.warn("Credentials configured for plaintext HTTP agent URI: {}. " + + "Use HTTPS in production to prevent credential interception.", + agentUri); + } + } + + // -- Builder -- + + public static final class Builder { + private @Nullable String id; + private @Nullable URI agentUri; + private Protocol protocol = Protocol.MCP; + private @Nullable String authToken; + private @Nullable BasicCredentials basicAuth; + private @Nullable OAuthClientCredentials oauthClientCredentials; + private @Nullable OAuthTokens oauthTokens; + private @Nullable String webhookUrlTemplate; + private @Nullable String webhookSecret; + private @Nullable AdcpVersion adcpVersion; + private Map extraHeaders = Map.of(); + + private Builder() {} + + /** Required: unique identifier for this agent. */ + public Builder id(String id) { + this.id = Objects.requireNonNull(id); + return this; + } + + /** Required: the agent's base URI. */ + public Builder agentUri(URI agentUri) { + this.agentUri = Objects.requireNonNull(agentUri); + return this; + } + + /** Transport protocol. Defaults to {@link Protocol#MCP}. */ + public Builder protocol(Protocol protocol) { + this.protocol = Objects.requireNonNull(protocol); + return this; + } + + /** Static Bearer token. Mutually exclusive with other auth. */ + public Builder authToken(@Nullable String authToken) { + this.authToken = authToken; + return this; + } + + /** HTTP Basic credentials. Mutually exclusive with other auth. */ + public Builder basicAuth(@Nullable BasicCredentials basicAuth) { + this.basicAuth = basicAuth; + return this; + } + + /** OAuth client-credentials config. Mutually exclusive with other auth. */ + public Builder oauthClientCredentials(@Nullable OAuthClientCredentials oauthCC) { + this.oauthClientCredentials = oauthCC; + return this; + } + + /** OAuth auth-code tokens. Mutually exclusive with other auth. */ + public Builder oauthTokens(@Nullable OAuthTokens oauthTokens) { + this.oauthTokens = oauthTokens; + return this; + } + + /** Webhook URL template for async task results. */ + public Builder webhookUrlTemplate(@Nullable String webhookUrlTemplate) { + this.webhookUrlTemplate = webhookUrlTemplate; + return this; + } + + /** HMAC-SHA256 secret for webhook verification. */ + public Builder webhookSecret(@Nullable String webhookSecret) { + this.webhookSecret = webhookSecret; + return this; + } + + /** Pin a specific AdCP protocol version. */ + public Builder adcpVersion(@Nullable AdcpVersion adcpVersion) { + this.adcpVersion = adcpVersion; + return this; + } + + /** + * Pin a specific AdCP protocol version by release-precision string + * (e.g. {@code "3.0"}, {@code "3.1"}). + * + *

Equivalent to {@code adcpVersion(AdcpVersion.of(releaseVersion))}. + * Cross-major pins are rejected at {@link AdcpClient} build time. + */ + public Builder adcpVersion(String releaseVersion) { + return adcpVersion(AdcpVersion.of(releaseVersion)); + } + + /** Extra headers injected into every request to this agent. */ + public Builder extraHeaders(Map extraHeaders) { + this.extraHeaders = Map.copyOf(extraHeaders); + return this; + } + + /** Builds the config, validating required fields and auth exclusivity. */ + public AgentConfig build() { + if (id == null) { + throw new ConfigurationError("AgentConfig.id is required", "id"); + } + if (agentUri == null) { + throw new ConfigurationError("AgentConfig.agentUri is required", "agentUri"); + } + return new AgentConfig( + id, agentUri, protocol, + authToken, basicAuth, oauthClientCredentials, oauthTokens, + webhookUrlTemplate, webhookSecret, adcpVersion, + extraHeaders); + } + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/Protocol.java b/adcp/src/main/java/org/adcontextprotocol/adcp/Protocol.java new file mode 100644 index 0000000..44e1918 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/Protocol.java @@ -0,0 +1,19 @@ +package org.adcontextprotocol.adcp; + +/** + * Transport protocol used to communicate with an agent. + * + *

Determines which dispatch path {@code ProtocolClient} uses: + *

    + *
  • {@link #MCP} — Model Context Protocol (StreamableHTTP + SSE fallback)
  • + *
  • {@link #A2A} — Agent-to-Agent protocol (JSON-RPC 2.0 + SSE streaming)
  • + *
+ */ +public enum Protocol { + + /** Model Context Protocol — the primary transport for AdCP. */ + MCP, + + /** Agent-to-Agent protocol (v0.4). */ + A2A +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/auth/AuthChallengeInfo.java b/adcp/src/main/java/org/adcontextprotocol/adcp/auth/AuthChallengeInfo.java new file mode 100644 index 0000000..045e5b2 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/auth/AuthChallengeInfo.java @@ -0,0 +1,31 @@ +package org.adcontextprotocol.adcp.auth; + +import org.jspecify.annotations.Nullable; + +/** + * Parsed {@code WWW-Authenticate} challenge from an HTTP 401 response. + * + *

Fields follow RFC 9110 §11.6.1. The {@link #scheme()} is always + * lowercased for case-insensitive comparison. + * + * @param scheme auth scheme, lowercased (e.g. "bearer", "basic") + * @param realm the protection realm, if present + * @param scope OAuth scope, if present + * @param error OAuth error code, if present + * @param errorDescription human-readable error description, if present + */ +public record AuthChallengeInfo( + String scheme, + @Nullable String realm, + @Nullable String scope, + @Nullable String error, + @Nullable String errorDescription +) { + public AuthChallengeInfo { + java.util.Objects.requireNonNull(scheme, "scheme"); + scheme = scheme.toLowerCase(java.util.Locale.ROOT); + if (scheme.isBlank()) { + throw new IllegalArgumentException("scheme must not be blank"); + } + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/auth/AuthTokenResolver.java b/adcp/src/main/java/org/adcontextprotocol/adcp/auth/AuthTokenResolver.java new file mode 100644 index 0000000..8ce8ab0 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/auth/AuthTokenResolver.java @@ -0,0 +1,66 @@ +package org.adcontextprotocol.adcp.auth; + +import org.adcontextprotocol.adcp.AgentConfig; +import org.adcontextprotocol.adcp.error.FeatureUnsupportedError; + +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * Resolves auth headers from an {@link AgentConfig}. + * + *

Supports: + *

    + *
  • Static Bearer token → {@code Authorization: Bearer} + {@code x-adcp-auth}
  • + *
  • HTTP Basic → {@code Authorization: Basic}
  • + *
  • OAuth auth-code tokens → {@code Authorization: Bearer}
  • + *
  • No auth → empty map
  • + *
+ * + *

OAuth client-credentials flow (token exchange/refresh) is handled + * separately by the transport layer before calling this resolver. + */ +public final class AuthTokenResolver { + + private AuthTokenResolver() {} + + /** + * Resolves the auth headers for the given agent config. + * + * @param config the agent configuration + * @return map of header name → value (may be empty) + */ + public static Map resolve(AgentConfig config) { + Map headers = new LinkedHashMap<>(); + + if (config.authToken() != null) { + // Static bearer token — send both headers for backward compat + headers.put("Authorization", "Bearer " + config.authToken()); + headers.put("x-adcp-auth", config.authToken()); + } else if (config.basicAuth() != null) { + // HTTP Basic (7.2.0 delta) + BasicCredentials creds = config.basicAuth(); + String encoded = Base64.getEncoder().encodeToString( + (creds.username() + ":" + creds.password()) + .getBytes(StandardCharsets.UTF_8)); + headers.put("Authorization", "Basic " + encoded); + } else if (config.oauthTokens() != null) { + // OAuth auth-code tokens + headers.put("Authorization", "Bearer " + config.oauthTokens().accessToken()); + } else if (config.oauthClientCredentials() != null) { + // OAuth client-credentials flow is not yet implemented. + // Throw explicitly so callers know (rather than silently + // sending an unauthenticated request that fails with 401). + throw new FeatureUnsupportedError( + List.of("OAuth client-credentials token exchange"), + List.of("Static bearer token (authToken)", + "HTTP Basic auth (basicAuth)", + "OAuth auth-code tokens (oauthTokens)")); + } + + return Map.copyOf(headers); + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/auth/BasicCredentials.java b/adcp/src/main/java/org/adcontextprotocol/adcp/auth/BasicCredentials.java new file mode 100644 index 0000000..60e290e --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/auth/BasicCredentials.java @@ -0,0 +1,33 @@ +package org.adcontextprotocol.adcp.auth; + +import java.util.Objects; + +/** + * HTTP Basic credentials (RFC 7617). + * + *

Validated at construction: neither {@code username} nor + * {@code password} may be blank. + * + * @param username the username + * @param password the password + */ +public record BasicCredentials(String username, String password) { + + public BasicCredentials { + Objects.requireNonNull(username, "username"); + Objects.requireNonNull(password, "password"); + if (username.isBlank()) { + throw new IllegalArgumentException("username must not be blank"); + } + if (username.contains(":")) { + throw new IllegalArgumentException("username must not contain ':' (RFC 7617 §2)"); + } + // Blank passwords are allowed — many platforms use the + // username=token, password="" pattern (e.g. GitHub PATs, Stripe). + } + + @Override + public String toString() { + return "BasicCredentials[username=" + username + ", password=]"; + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/auth/OAuthClientCredentials.java b/adcp/src/main/java/org/adcontextprotocol/adcp/auth/OAuthClientCredentials.java new file mode 100644 index 0000000..743f01f --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/auth/OAuthClientCredentials.java @@ -0,0 +1,40 @@ +package org.adcontextprotocol.adcp.auth; + +import org.jspecify.annotations.Nullable; + +import java.util.Objects; + +/** + * OAuth 2.0 client credentials for the client-credentials grant flow. + * + * @param clientId the OAuth client ID + * @param clientSecret the OAuth client secret + * @param tokenEndpoint the token endpoint URI + * @param scope optional scope string + */ +public record OAuthClientCredentials( + String clientId, + String clientSecret, + String tokenEndpoint, + @Nullable String scope +) { + + public OAuthClientCredentials { + Objects.requireNonNull(clientId, "clientId"); + Objects.requireNonNull(clientSecret, "clientSecret"); + Objects.requireNonNull(tokenEndpoint, "tokenEndpoint"); + if (clientId.isBlank()) { + throw new IllegalArgumentException("clientId must not be blank"); + } + if (clientSecret.isBlank()) { + throw new IllegalArgumentException("clientSecret must not be blank"); + } + } + + @Override + public String toString() { + return "OAuthClientCredentials[clientId=" + clientId + + ", clientSecret=, tokenEndpoint=" + tokenEndpoint + + ", scope=" + scope + "]"; + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/auth/OAuthMetadataInfo.java b/adcp/src/main/java/org/adcontextprotocol/adcp/auth/OAuthMetadataInfo.java new file mode 100644 index 0000000..9f30ed8 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/auth/OAuthMetadataInfo.java @@ -0,0 +1,24 @@ +package org.adcontextprotocol.adcp.auth; + +import org.jspecify.annotations.Nullable; + +/** + * OAuth metadata discovered from an agent, typically via RFC 9728 + * Protected Resource Metadata or from the MCP OAuth flow. + * + * @param authorizationEndpoint the OAuth authorization endpoint + * @param tokenEndpoint the OAuth token endpoint + * @param registrationEndpoint optional dynamic client registration endpoint + * @param issuer optional OAuth issuer identifier + */ +public record OAuthMetadataInfo( + String authorizationEndpoint, + String tokenEndpoint, + @Nullable String registrationEndpoint, + @Nullable String issuer +) { + public OAuthMetadataInfo { + java.util.Objects.requireNonNull(authorizationEndpoint, "authorizationEndpoint"); + java.util.Objects.requireNonNull(tokenEndpoint, "tokenEndpoint"); + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/auth/OAuthTokens.java b/adcp/src/main/java/org/adcontextprotocol/adcp/auth/OAuthTokens.java new file mode 100644 index 0000000..739f132 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/auth/OAuthTokens.java @@ -0,0 +1,60 @@ +package org.adcontextprotocol.adcp.auth; + +import org.jspecify.annotations.Nullable; + +import java.time.Instant; +import java.util.Objects; + +/** + * OAuth 2.0 tokens obtained from an auth-code or refresh grant. + * + * @param accessToken the access token + * @param refreshToken the refresh token (may be {@code null} for CC grants) + * @param expiresAt when the access token expires (may be {@code null}) + * @param tokenType token type (typically "Bearer") + */ +public record OAuthTokens( + String accessToken, + @Nullable String refreshToken, + @Nullable Instant expiresAt, + String tokenType +) { + + public OAuthTokens { + Objects.requireNonNull(accessToken, "accessToken"); + Objects.requireNonNull(tokenType, "tokenType"); + if (accessToken.isBlank()) { + throw new IllegalArgumentException("accessToken must not be blank"); + } + if (accessToken.indexOf('\r') >= 0 || accessToken.indexOf('\n') >= 0) { + throw new IllegalArgumentException("accessToken must not contain CR/LF characters"); + } + } + + /** Creates a Bearer token with the given access token. */ + public static OAuthTokens bearer(String accessToken) { + return new OAuthTokens(accessToken, null, null, "Bearer"); + } + + /** Creates a Bearer token with refresh token. */ + public static OAuthTokens bearer(String accessToken, @Nullable String refreshToken, + @Nullable Instant expiresAt) { + return new OAuthTokens(accessToken, refreshToken, expiresAt, "Bearer"); + } + + /** Whether the access token has expired (with 30s safety margin). */ + public boolean isExpired() { + if (expiresAt == null) { + return false; + } + return Instant.now().plusSeconds(30).isAfter(expiresAt); + } + + @Override + public String toString() { + return "OAuthTokens[accessToken=, refreshToken=" + + (refreshToken != null ? "" : "null") + + ", expiresAt=" + expiresAt + + ", tokenType=" + tokenType + "]"; + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/auth/WwwAuthenticateParser.java b/adcp/src/main/java/org/adcontextprotocol/adcp/auth/WwwAuthenticateParser.java new file mode 100644 index 0000000..d85d509 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/auth/WwwAuthenticateParser.java @@ -0,0 +1,96 @@ +package org.adcontextprotocol.adcp.auth; + +import org.jspecify.annotations.Nullable; + +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Parses {@code WWW-Authenticate} headers per RFC 9110 §11.6.1. + * + *

Handles the common cases seen in AdCP agents: + *

    + *
  • {@code Bearer realm="example", error="invalid_token"}
  • + *
  • {@code Basic realm="Agent"}
  • + *
  • {@code Bearer} (no parameters)
  • + *
+ * + *

Only parses the first challenge in multi-challenge headers. + * Quoted-pair escapes ({@code \"}) inside quoted strings are handled + * per RFC 9110 §5.6.4. + */ +public final class WwwAuthenticateParser { + + // Matches: scheme followed by optional key=value pairs + private static final Pattern SCHEME_PATTERN = + Pattern.compile("^(\\S+)\\s*(.*)$"); + + // Matches: key="value" (with quoted-pair support) or key=token + // Group 1: key + // Group 2: quoted string content (may contain escaped chars) + // Group 3: unquoted token value + private static final Pattern PARAM_PATTERN = + Pattern.compile("(\\w+)\\s*=\\s*(?:\"((?:[^\"\\\\]|\\\\.)*)\"|([^\\s,]+))"); + + /** Maximum number of parameters to parse (DoS guard). */ + private static final int MAX_PARAMS = 16; + + private WwwAuthenticateParser() {} + + /** + * Parses a {@code WWW-Authenticate} header value into an + * {@link AuthChallengeInfo}. + * + * @param header the header value (e.g. {@code "Bearer realm=\"example\""}) + * @return parsed challenge info, or {@code null} if the header is blank + */ + public static @Nullable AuthChallengeInfo parse(@Nullable String header) { + if (header == null || header.isBlank()) { + return null; + } + + Matcher schemeMatcher = SCHEME_PATTERN.matcher(header.trim()); + if (!schemeMatcher.matches()) { + return null; + } + + String scheme = schemeMatcher.group(1).toLowerCase(java.util.Locale.ROOT); + String paramString = schemeMatcher.group(2); + + Map params = parseParams(paramString); + + return new AuthChallengeInfo( + scheme, + params.get("realm"), + params.get("scope"), + params.get("error"), + params.get("error_description")); + } + + private static Map parseParams(String paramString) { + Map params = new LinkedHashMap<>(); + if (paramString == null || paramString.isBlank()) { + return params; + } + + Matcher paramMatcher = PARAM_PATTERN.matcher(paramString); + int count = 0; + while (paramMatcher.find() && count < MAX_PARAMS) { + String key = paramMatcher.group(1).toLowerCase(java.util.Locale.ROOT); + String value; + if (paramMatcher.group(2) != null) { + // Quoted string — unescape quoted-pairs (RFC 9110 §5.6.4) + value = paramMatcher.group(2).replace("\\\"", "\"") + .replace("\\\\", "\\"); + } else { + value = paramMatcher.group(3); + } + params.put(key, value); + count++; + } + + return params; + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/auth/package-info.java b/adcp/src/main/java/org/adcontextprotocol/adcp/auth/package-info.java new file mode 100644 index 0000000..4ec40bd --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/auth/package-info.java @@ -0,0 +1,5 @@ +/** + * Authentication and authorization primitives for AdCP transport. + */ +@org.jspecify.annotations.NullMarked +package org.adcontextprotocol.adcp.auth; diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/error/AdcpError.java b/adcp/src/main/java/org/adcontextprotocol/adcp/error/AdcpError.java new file mode 100644 index 0000000..92d4095 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/error/AdcpError.java @@ -0,0 +1,58 @@ +package org.adcontextprotocol.adcp.error; + +import org.jspecify.annotations.Nullable; + +/** + * Base class for all AdCP SDK errors. Each subclass carries a unique + * {@link #code()} string that callers can switch on. + * + *

Mirrors the TS SDK's {@code ADCPError} hierarchy. All AdCP errors + * are unchecked ({@link RuntimeException}) — callers who want to handle + * them catch specific subclasses. + */ +public abstract sealed class AdcpError extends RuntimeException + permits ProtocolError, + AuthenticationRequiredError, + TaskTimeoutError, + TaskAbortedError, + DeferredTaskError, + ValidationError, + ConfigurationError, + VersionUnsupportedError, + AgentNotFoundError, + UnsupportedTaskError, + FeatureUnsupportedError, + ResponseTooLargeError, + IdempotencyConflictError, + IdempotencyExpiredError { + + @java.io.Serial + private static final long serialVersionUID = 1L; + + private final String code; + @SuppressWarnings("serial") + private final @Nullable Object details; + + protected AdcpError(String code, String message, @Nullable Object details) { + super(message); + this.code = code; + this.details = details; + } + + protected AdcpError(String code, String message, @Nullable Object details, + @Nullable Throwable cause) { + super(message, cause); + this.code = code; + this.details = details; + } + + /** Stable error code for programmatic matching (e.g. "PROTOCOL_ERROR"). */ + public String code() { + return code; + } + + /** Optional structured details about the error. */ + public @Nullable Object details() { + return details; + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/error/AgentNotFoundError.java b/adcp/src/main/java/org/adcontextprotocol/adcp/error/AgentNotFoundError.java new file mode 100644 index 0000000..b100380 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/error/AgentNotFoundError.java @@ -0,0 +1,31 @@ +package org.adcontextprotocol.adcp.error; + +import java.util.List; + +/** The requested agent was not found in the client configuration. */ +public final class AgentNotFoundError extends AdcpError { + + @java.io.Serial + private static final long serialVersionUID = 1L; + + private final String agentId; + @SuppressWarnings("serial") + private final List availableAgents; + + public AgentNotFoundError(String agentId, List availableAgents) { + super("AGENT_NOT_FOUND", + "Agent not found: " + agentId + + ". Available: " + availableAgents, + null); + this.agentId = agentId; + this.availableAgents = List.copyOf(availableAgents); + } + + public String agentId() { + return agentId; + } + + public List availableAgents() { + return availableAgents; + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/error/AuthenticationRequiredError.java b/adcp/src/main/java/org/adcontextprotocol/adcp/error/AuthenticationRequiredError.java new file mode 100644 index 0000000..a5f76ce --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/error/AuthenticationRequiredError.java @@ -0,0 +1,71 @@ +package org.adcontextprotocol.adcp.error; + +import org.adcontextprotocol.adcp.auth.AuthChallengeInfo; +import org.adcontextprotocol.adcp.auth.OAuthMetadataInfo; +import org.jspecify.annotations.Nullable; + +import java.net.URI; + +/** + * The agent requires authentication. Carries parsed {@code WWW-Authenticate} + * challenge info and optional OAuth metadata for programmatic auth flows. + * + *

The {@link #challenge()} field is populated on a best-effort basis via + * a HEAD (or OPTIONS) probe when a 401 is detected. It may be {@code null} + * even when authentication is genuinely required — for example, if the + * agent endpoint returns 405 for both HEAD and OPTIONS, or if the probe + * itself fails. Callers should not assume a {@code null} challenge means + * "no auth needed"; it means the auth scheme could not be determined. + */ +public final class AuthenticationRequiredError extends AdcpError { + + @java.io.Serial + private static final long serialVersionUID = 1L; + + private final URI agentUri; + @SuppressWarnings("serial") + private final @Nullable AuthChallengeInfo challenge; + @SuppressWarnings("serial") + private final @Nullable OAuthMetadataInfo oauthMetadata; + + public AuthenticationRequiredError( + URI agentUri, + @Nullable AuthChallengeInfo challenge, + @Nullable OAuthMetadataInfo oauthMetadata) { + this(agentUri, challenge, oauthMetadata, null); + } + + public AuthenticationRequiredError( + URI agentUri, + @Nullable AuthChallengeInfo challenge, + @Nullable OAuthMetadataInfo oauthMetadata, + @Nullable Throwable cause) { + super("AUTHENTICATION_REQUIRED", + "Authentication required for agent: " + agentUri, + null, cause); + this.agentUri = agentUri; + this.challenge = challenge; + this.oauthMetadata = oauthMetadata; + } + + public URI agentUri() { + return agentUri; + } + + public @Nullable AuthChallengeInfo challenge() { + return challenge; + } + + public @Nullable OAuthMetadataInfo oauthMetadata() { + return oauthMetadata; + } + + public boolean hasOAuth() { + return oauthMetadata != null; + } + + /** The suggested auth scheme (lowercased), or {@code null} if unknown. */ + public @Nullable String suggestedScheme() { + return challenge != null ? challenge.scheme() : null; + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/error/ConfigurationError.java b/adcp/src/main/java/org/adcontextprotocol/adcp/error/ConfigurationError.java new file mode 100644 index 0000000..3decb5b --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/error/ConfigurationError.java @@ -0,0 +1,21 @@ +package org.adcontextprotocol.adcp.error; + +import org.jspecify.annotations.Nullable; + +/** Client or agent configuration is invalid. */ +public final class ConfigurationError extends AdcpError { + + @java.io.Serial + private static final long serialVersionUID = 1L; + + private final @Nullable String configField; + + public ConfigurationError(String message, @Nullable String configField) { + super("CONFIGURATION_ERROR", message, null); + this.configField = configField; + } + + public @Nullable String configField() { + return configField; + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/error/DeferredTaskError.java b/adcp/src/main/java/org/adcontextprotocol/adcp/error/DeferredTaskError.java new file mode 100644 index 0000000..39818ca --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/error/DeferredTaskError.java @@ -0,0 +1,19 @@ +package org.adcontextprotocol.adcp.error; + +/** A task was deferred for async processing. Carries the deferral token. */ +public final class DeferredTaskError extends AdcpError { + + @java.io.Serial + private static final long serialVersionUID = 1L; + + private final String token; + + public DeferredTaskError(String token) { + super("TASK_DEFERRED", "Task deferred with token: " + token, null); + this.token = token; + } + + public String token() { + return token; + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/error/FeatureUnsupportedError.java b/adcp/src/main/java/org/adcontextprotocol/adcp/error/FeatureUnsupportedError.java new file mode 100644 index 0000000..d1e5e77 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/error/FeatureUnsupportedError.java @@ -0,0 +1,33 @@ +package org.adcontextprotocol.adcp.error; + +import java.util.List; + +/** The agent lacks features required by the caller. */ +public final class FeatureUnsupportedError extends AdcpError { + + @java.io.Serial + private static final long serialVersionUID = 1L; + + @SuppressWarnings("serial") + private final List unsupportedFeatures; + @SuppressWarnings("serial") + private final List declaredFeatures; + + public FeatureUnsupportedError( + List unsupportedFeatures, + List declaredFeatures) { + super("FEATURE_UNSUPPORTED", + "Unsupported features: " + unsupportedFeatures, + null); + this.unsupportedFeatures = List.copyOf(unsupportedFeatures); + this.declaredFeatures = List.copyOf(declaredFeatures); + } + + public List unsupportedFeatures() { + return unsupportedFeatures; + } + + public List declaredFeatures() { + return declaredFeatures; + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/error/IdempotencyConflictError.java b/adcp/src/main/java/org/adcontextprotocol/adcp/error/IdempotencyConflictError.java new file mode 100644 index 0000000..daa591f --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/error/IdempotencyConflictError.java @@ -0,0 +1,12 @@ +package org.adcontextprotocol.adcp.error; + +/** An idempotency key collided with an in-flight or completed request. */ +public final class IdempotencyConflictError extends AdcpError { + + @java.io.Serial + private static final long serialVersionUID = 1L; + + public IdempotencyConflictError(String message) { + super("IDEMPOTENCY_CONFLICT", message, null); + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/error/IdempotencyExpiredError.java b/adcp/src/main/java/org/adcontextprotocol/adcp/error/IdempotencyExpiredError.java new file mode 100644 index 0000000..7e1956e --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/error/IdempotencyExpiredError.java @@ -0,0 +1,12 @@ +package org.adcontextprotocol.adcp.error; + +/** An idempotency key has expired (TTL exceeded). */ +public final class IdempotencyExpiredError extends AdcpError { + + @java.io.Serial + private static final long serialVersionUID = 1L; + + public IdempotencyExpiredError(String message) { + super("IDEMPOTENCY_EXPIRED", message, null); + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/error/ProtocolError.java b/adcp/src/main/java/org/adcontextprotocol/adcp/error/ProtocolError.java new file mode 100644 index 0000000..2517cf8 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/error/ProtocolError.java @@ -0,0 +1,22 @@ +package org.adcontextprotocol.adcp.error; + +import org.jspecify.annotations.Nullable; + +/** Wraps MCP or A2A transport failures. */ +public final class ProtocolError extends AdcpError { + + @java.io.Serial + private static final long serialVersionUID = 1L; + + private final String protocol; + + public ProtocolError(String protocol, String message, @Nullable Throwable cause) { + super("PROTOCOL_ERROR", message, null, cause); + this.protocol = protocol; + } + + /** {@code "mcp"} or {@code "a2a"}. */ + public String protocol() { + return protocol; + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/error/ResponseTooLargeError.java b/adcp/src/main/java/org/adcontextprotocol/adcp/error/ResponseTooLargeError.java new file mode 100644 index 0000000..c5441d0 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/error/ResponseTooLargeError.java @@ -0,0 +1,38 @@ +package org.adcontextprotocol.adcp.error; + +import org.jspecify.annotations.Nullable; + +import java.net.URI; + +/** Response body exceeded the configured maximum size. */ +public final class ResponseTooLargeError extends AdcpError { + + @java.io.Serial + private static final long serialVersionUID = 1L; + + private final long limit; + private final long bytesRead; + private final @Nullable URI url; + + public ResponseTooLargeError(long limit, long bytesRead, @Nullable URI url) { + super("RESPONSE_TOO_LARGE", + "Response exceeded " + limit + " bytes (read " + bytesRead + ")" + + (url != null ? " from " + url : ""), + null); + this.limit = limit; + this.bytesRead = bytesRead; + this.url = url; + } + + public long limit() { + return limit; + } + + public long bytesRead() { + return bytesRead; + } + + public @Nullable URI url() { + return url; + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/error/TaskAbortedError.java b/adcp/src/main/java/org/adcontextprotocol/adcp/error/TaskAbortedError.java new file mode 100644 index 0000000..c467ef3 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/error/TaskAbortedError.java @@ -0,0 +1,24 @@ +package org.adcontextprotocol.adcp.error; + +import org.jspecify.annotations.Nullable; + +/** A task was aborted by the caller or the agent. */ +public final class TaskAbortedError extends AdcpError { + + @java.io.Serial + private static final long serialVersionUID = 1L; + + private final String taskId; + + public TaskAbortedError(String taskId, @Nullable String reason) { + super("TASK_ABORTED", + "Task aborted: " + taskId + + (reason != null ? " (" + reason + ")" : ""), + null); + this.taskId = taskId; + } + + public String taskId() { + return taskId; + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/error/TaskTimeoutError.java b/adcp/src/main/java/org/adcontextprotocol/adcp/error/TaskTimeoutError.java new file mode 100644 index 0000000..5435b68 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/error/TaskTimeoutError.java @@ -0,0 +1,30 @@ +package org.adcontextprotocol.adcp.error; + +import org.jspecify.annotations.Nullable; + +/** A task exceeded its working timeout. */ +public final class TaskTimeoutError extends AdcpError { + + @java.io.Serial + private static final long serialVersionUID = 1L; + + private final @Nullable String taskId; + private final long timeoutMs; + + public TaskTimeoutError(@Nullable String taskId, long timeoutMs) { + super("TASK_TIMEOUT", + "Task timed out after " + timeoutMs + "ms" + + (taskId != null ? " (taskId=" + taskId + ")" : ""), + null); + this.taskId = taskId; + this.timeoutMs = timeoutMs; + } + + public @Nullable String taskId() { + return taskId; + } + + public long timeoutMs() { + return timeoutMs; + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/error/UnsupportedTaskError.java b/adcp/src/main/java/org/adcontextprotocol/adcp/error/UnsupportedTaskError.java new file mode 100644 index 0000000..944e007 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/error/UnsupportedTaskError.java @@ -0,0 +1,21 @@ +package org.adcontextprotocol.adcp.error; + +/** The agent does not support the requested tool/task. */ +public final class UnsupportedTaskError extends AdcpError { + + @java.io.Serial + private static final long serialVersionUID = 1L; + + private final String taskName; + + public UnsupportedTaskError(String taskName) { + super("UNSUPPORTED_TASK", + "Unsupported task: " + taskName, + null); + this.taskName = taskName; + } + + public String taskName() { + return taskName; + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/error/ValidationError.java b/adcp/src/main/java/org/adcontextprotocol/adcp/error/ValidationError.java new file mode 100644 index 0000000..214307d --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/error/ValidationError.java @@ -0,0 +1,44 @@ +package org.adcontextprotocol.adcp.error; + +import org.jspecify.annotations.Nullable; + +import java.util.List; + +/** + * Request or response validation failed. + * + *

The {@link #path()} carries a JSON-pointer path as a list of segments + * (e.g. {@code ["products", "3", "formats", "0", "duration"]}), matching + * the TS SDK's wire format. The {@link #schemaUri()} is the failing + * schema's {@code $id} when available. + */ +public final class ValidationError extends AdcpError { + + @java.io.Serial + private static final long serialVersionUID = 1L; + + @SuppressWarnings("serial") // List.copyOf() returns a Serializable impl + private final List path; + private final @Nullable String schemaUri; + + public ValidationError(String message, @Nullable List path, + @Nullable String schemaUri) { + super("VALIDATION_ERROR", message, null); + this.path = path != null ? List.copyOf(path) : List.of(); + this.schemaUri = schemaUri; + } + + public ValidationError(String message, @Nullable String field) { + this(message, field != null ? List.of(field) : null, null); + } + + /** JSON-pointer path to the failing field (may be empty). */ + public List path() { + return path; + } + + /** The {@code $id} of the schema that failed validation, if available. */ + public @Nullable String schemaUri() { + return schemaUri; + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/error/VersionUnsupportedError.java b/adcp/src/main/java/org/adcontextprotocol/adcp/error/VersionUnsupportedError.java new file mode 100644 index 0000000..1f1aaf9 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/error/VersionUnsupportedError.java @@ -0,0 +1,49 @@ +package org.adcontextprotocol.adcp.error; + +import org.jspecify.annotations.Nullable; + +import java.net.URI; + +/** The agent does not support the requested protocol version. */ +public final class VersionUnsupportedError extends AdcpError { + + @java.io.Serial + private static final long serialVersionUID = 1L; + + private final @Nullable String taskType; + private final String reason; + private final @Nullable String actualVersion; + private final @Nullable URI agentUri; + + public VersionUnsupportedError( + @Nullable String taskType, + String reason, + @Nullable String actualVersion, + @Nullable URI agentUri) { + super("VERSION_UNSUPPORTED", + "Version unsupported: " + reason + + (taskType != null ? " (task=" + taskType + ")" : ""), + null); + this.taskType = taskType; + this.reason = reason; + this.actualVersion = actualVersion; + this.agentUri = agentUri; + } + + public @Nullable String taskType() { + return taskType; + } + + /** {@code "version"}, {@code "idempotency"}, or a domain-specific reason. */ + public String reason() { + return reason; + } + + public @Nullable String actualVersion() { + return actualVersion; + } + + public @Nullable URI agentUri() { + return agentUri; + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/error/package-info.java b/adcp/src/main/java/org/adcontextprotocol/adcp/error/package-info.java new file mode 100644 index 0000000..d1fdb6c --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/error/package-info.java @@ -0,0 +1,5 @@ +/** + * AdCP error types. All SDK errors extend {@link org.adcontextprotocol.adcp.error.AdcpError}. + */ +@org.jspecify.annotations.NullMarked +package org.adcontextprotocol.adcp.error; diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/http/AdcpHttpClient.java b/adcp/src/main/java/org/adcontextprotocol/adcp/http/AdcpHttpClient.java new file mode 100644 index 0000000..dab20fd --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/http/AdcpHttpClient.java @@ -0,0 +1,383 @@ +package org.adcontextprotocol.adcp.http; + +import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.InetAddress; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.Map; +import java.util.Objects; + +/** + * SSRF-safe HTTP client for AdCP outbound requests. + * + *

Implements the four mitigations from {@code specs/ssrf-baseline.md}: + *

    + *
  1. Resolve DNS once, validate the full address set
  2. + *
  3. Validate resolved addresses against the SSRF policy
  4. + *
  5. {@code redirect: manual} (no transparent redirect-follow)
  6. + *
  7. Body cap (default 4 KiB for probes, configurable per call)
  8. + *
+ * + *

Every outbound HTTP call in the SDK routes through this client. + * Built on {@link java.net.http.HttpClient} (JDK 21). + * + *

The client keeps the original URI authority unchanged so HTTPS uses the + * intended hostname for TLS SNI and hostname verification. DNS validation + * therefore happens at resolve time only and accepts the remaining TOCTOU + * window before connect. + * + * @see SsrfPolicy + * @see DnsPinResolver + */ +public final class AdcpHttpClient implements AutoCloseable { + + private static final Logger log = LoggerFactory.getLogger(AdcpHttpClient.class); + + /** Default body cap for discovery probes (4 KiB). */ + public static final long DEFAULT_MAX_RESPONSE_BYTES = 4 * 1024; + + private static final Duration DEFAULT_CONNECT_TIMEOUT = Duration.ofSeconds(10); + private static final Duration DEFAULT_READ_TIMEOUT = Duration.ofSeconds(30); + private static final String DEFAULT_USER_AGENT = "adcp-java-sdk/0.1"; + + private final SsrfPolicy ssrfPolicy; + private final long maxResponseBytes; + private final Duration connectTimeout; + private final Duration readTimeout; + private final String userAgent; + private final boolean requireHttps; + private final HttpClient httpClient; + + private AdcpHttpClient(Builder builder) { + this.ssrfPolicy = builder.ssrfPolicy; + this.maxResponseBytes = builder.maxResponseBytes; + this.connectTimeout = builder.connectTimeout; + this.readTimeout = builder.readTimeout; + this.userAgent = builder.userAgent; + this.requireHttps = builder.requireHttps; + this.httpClient = HttpClient.newBuilder() + .connectTimeout(this.connectTimeout) + .followRedirects(HttpClient.Redirect.NEVER) + .build(); + } + + /** Creates a new builder with strict SSRF policy defaults. */ + public static Builder builder() { + return new Builder(); + } + + /** + * Sends an HTTP request with SSRF protection. + * + *

The hostname is resolved via DNS and all addresses are validated + * against the {@link SsrfPolicy}. The original URI is preserved so TLS + * SNI and hostname verification continue to use the hostname instead of + * an IP literal. Redirects are never followed automatically. The response + * body is capped at {@link #maxResponseBytes()}. + * + * @param method HTTP method (GET, POST, etc.) + * @param uri target URI + * @param headers additional headers to include + * @param body request body, or {@code null} for bodyless requests + * @return the response with possible body truncation + * @throws SsrfBlockedException if the target address is blocked + * @throws IOException on transport errors + * @throws InterruptedException if the calling thread is interrupted + */ + public AdcpHttpResponse send( + String method, + URI uri, + Map headers, + @Nullable byte[] body) throws IOException, InterruptedException { + + Objects.requireNonNull(uri, "uri"); + Objects.requireNonNull(method, "method"); + + // Step 0: Enforce HTTPS when requireHttps is enabled. + // Localhost/loopback is exempt for local development. + if (requireHttps && "http".equalsIgnoreCase(uri.getScheme())) { + String host = uri.getHost(); + if (host != null && !isLoopback(host)) { + throw new IOException( + "Plain HTTP is not allowed when requireHttps is enabled: " + uri + + ". Use HTTPS or set requireHttps(false) for local development."); + } + } + + // Step 1: DNS resolve + SSRF validate. Keep the original URI so + // HTTPS continues to use the hostname for TLS and hostname checks, + // accepting the remaining resolve-to-connect TOCTOU window. + URI validatedUri = validateUri(uri); + + // Step 2: Build the request with the validated URI + HttpRequest.Builder requestBuilder = HttpRequest.newBuilder() + .uri(validatedUri) + .timeout(readTimeout) + .header("User-Agent", userAgent); + + // Add caller-supplied headers, skipping protected headers + headers.forEach((name, value) -> { + if (!isProtectedHeader(name)) { + requestBuilder.header(name, value); + } else { + log.debug("Skipping protected header from caller: {}", name); + } + }); + + // Set method + body + if (body != null) { + requestBuilder.method(method, HttpRequest.BodyPublishers.ofByteArray(body)); + } else { + requestBuilder.method(method, HttpRequest.BodyPublishers.noBody()); + } + + // Step 3: Send with body-capped response handler + HttpResponse response = httpClient.send( + requestBuilder.build(), + HttpResponse.BodyHandlers.ofInputStream()); + + // Step 4: Read body with cap enforcement. + // Ensure the InputStream is closed even if readBodyWithCap throws. + try { + return readBodyWithCap(response); + } catch (Throwable t) { + try { + response.body().close(); + } catch (Exception suppressed) { + t.addSuppressed(suppressed); + } + throw t; + } + } + + /** + * Convenience: GET request with no body. + */ + public AdcpHttpResponse get(URI uri, Map headers) + throws IOException, InterruptedException { + return send("GET", uri, headers, null); + } + + /** + * Convenience: POST request with a body. + */ + public AdcpHttpResponse post(URI uri, Map headers, byte[] body) + throws IOException, InterruptedException { + return send("POST", uri, headers, body); + } + + /** The SSRF policy in use. */ + public SsrfPolicy ssrfPolicy() { + return ssrfPolicy; + } + + /** Maximum response body size in bytes. */ + public long maxResponseBytes() { + return maxResponseBytes; + } + + /** + * Creates an MCP transport client builder with the same connection-timeout + * and redirect policy used by this client. + */ + public HttpClient.Builder newMcpClientBuilder() { + return HttpClient.newBuilder() + .connectTimeout(connectTimeout) + .followRedirects(HttpClient.Redirect.NEVER); + } + + @Override + public void close() { + httpClient.close(); + } + + // -- internal -- + + private URI validateUri(URI uri) throws IOException { + String scheme = uri.getScheme(); + if (!"http".equalsIgnoreCase(scheme) && !"https".equalsIgnoreCase(scheme)) { + throw new IOException("URI scheme must be http or https: " + uri); + } + + String host = uri.getHost(); + if (host == null) { + throw new IOException("URI has no host: " + uri); + } + + // Syntactic check for IP literals + if (isIpLiteral(host)) { + InetAddress literal = InetAddress.getByName(host); + DnsPinResolver.validateAddress(literal, ssrfPolicy); + return uri; + } + + // Resolve hostname and validate every address, but keep the original + // hostname in the URI so TLS SNI and hostname verification work. + // This accepts the remaining TOCTOU window between validation and connect. + DnsPinResolver.resolveAndPin(host, ssrfPolicy); + return uri; + } + + private static boolean isIpLiteral(String host) { + // IPv6 in URI brackets: [::1] + if (host.startsWith("[")) { + return true; + } + // Must have at least one dot for IPv4 + if (host.indexOf('.') < 0) { + return false; + } + // Check: all characters are digits and dots + for (int i = 0; i < host.length(); i++) { + char c = host.charAt(i); + if (c != '.' && (c < '0' || c > '9')) { + return false; + } + } + // Reject ambiguous octal/decimal literals (e.g. 0177.0.0.1). + // InetAddress.getAllByName may interpret leading-zero octets as + // octal on some JDKs, allowing SSRF bypass. + for (String octet : host.split("\\.", -1)) { + if (octet.length() > 1 && octet.startsWith("0")) { + throw new org.adcontextprotocol.adcp.http.SsrfBlockedException( + host, "Ambiguous IP literal with leading zeros (possible octal)"); + } + } + return true; + } + + private static boolean isLoopback(String host) { + return "localhost".equalsIgnoreCase(host) + || "127.0.0.1".equals(host) + || "[::1]".equals(host) + || "::1".equals(host); + } + + private AdcpHttpResponse readBodyWithCap(HttpResponse response) + throws IOException { + long cap = maxResponseBytes; + boolean truncated = false; + long totalRead = 0; + + try (InputStream is = response.body()) { + ByteArrayOutputStream baos = new ByteArrayOutputStream( + (int) Math.min(cap, 8192)); + byte[] buf = new byte[8192]; + int n; + while ((n = is.read(buf)) != -1) { + totalRead += n; + if (totalRead <= cap) { + baos.write(buf, 0, n); + } else if (!truncated) { + // Write only the portion up to the cap + int remaining = (int) (cap - (totalRead - n)); + if (remaining > 0) { + baos.write(buf, 0, remaining); + } + truncated = true; + log.debug("Response body truncated at {} bytes (cap={})", + totalRead, cap); + // Stop reading — don't consume the rest + break; + } + } + + return new AdcpHttpResponse( + response.statusCode(), + response.headers(), + baos.toByteArray(), + truncated, + totalRead); + } + } + + private static boolean isProtectedHeader(String name) { + return ProtectedHeaders.isProtected(name); + } + + // -- Builder -- + + public static final class Builder { + private SsrfPolicy ssrfPolicy = SsrfPolicy.strict(); + private long maxResponseBytes = DEFAULT_MAX_RESPONSE_BYTES; + private Duration connectTimeout = DEFAULT_CONNECT_TIMEOUT; + private Duration readTimeout = DEFAULT_READ_TIMEOUT; + private String userAgent = DEFAULT_USER_AGENT; + private boolean requireHttps = false; + + private Builder() {} + + /** + * Sets the SSRF policy. Defaults to {@link SsrfPolicy#strict()}. + * Use {@link SsrfPolicy#permissive()} only for local development + * against {@code localhost}. + */ + public Builder ssrfPolicy(SsrfPolicy ssrfPolicy) { + this.ssrfPolicy = Objects.requireNonNull(ssrfPolicy); + return this; + } + + /** + * Maximum response body size in bytes. Responses exceeding this + * are truncated and flagged via {@link AdcpHttpResponse#truncated()}. + * Default: {@value #DEFAULT_MAX_RESPONSE_BYTES} (4 KiB). + * Maximum: 64 MB. + */ + public Builder maxResponseBytes(long maxResponseBytes) { + if (maxResponseBytes <= 0 || maxResponseBytes > 64 * 1024 * 1024) { + throw new IllegalArgumentException( + "maxResponseBytes must be in (0, 67108864]: " + maxResponseBytes); + } + this.maxResponseBytes = maxResponseBytes; + return this; + } + + /** Connection timeout. Default: 10 seconds. */ + public Builder connectTimeout(Duration connectTimeout) { + this.connectTimeout = Objects.requireNonNull(connectTimeout); + return this; + } + + /** Read timeout per request. Default: 30 seconds. */ + public Builder readTimeout(Duration readTimeout) { + this.readTimeout = Objects.requireNonNull(readTimeout); + return this; + } + + /** User-Agent header value. */ + public Builder userAgent(String userAgent) { + this.userAgent = Objects.requireNonNull(userAgent); + return this; + } + + /** + * When {@code true}, rejects plain {@code http://} URIs for + * non-loopback hosts. Prevents credential leakage over unencrypted + * connections in production. + * + *

Localhost ({@code 127.0.0.1}, {@code ::1}, {@code localhost}) + * is always exempt for local development. + * + *

Default: {@code false} (warns only, via + * {@link org.adcontextprotocol.adcp.AgentConfig}). + */ + public Builder requireHttps(boolean requireHttps) { + this.requireHttps = requireHttps; + return this; + } + + /** Builds the client. */ + public AdcpHttpClient build() { + return new AdcpHttpClient(this); + } + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/http/AdcpHttpResponse.java b/adcp/src/main/java/org/adcontextprotocol/adcp/http/AdcpHttpResponse.java new file mode 100644 index 0000000..cde5f83 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/http/AdcpHttpResponse.java @@ -0,0 +1,78 @@ +package org.adcontextprotocol.adcp.http; + +import org.jspecify.annotations.Nullable; + +import java.net.http.HttpHeaders; +import java.util.Map; + +/** + * Response from {@link AdcpHttpClient#send}. Wraps the status code, + * headers, and body — with truncation tracking when the body cap + * is exceeded. + * + *

Note: Equality comparison is not meaningful for this record + * because it contains a byte array field. Use explicit content comparison + * via {@link java.util.Arrays#equals(byte[], byte[])} if needed. + * + * @param statusCode HTTP status code + * @param headers response headers + * @param body response body (possibly truncated) + * @param truncated {@code true} if the body was truncated at the configured cap + * @param bytesRead total bytes read before truncation (or full body length) + */ +public record AdcpHttpResponse( + int statusCode, + HttpHeaders headers, + byte[] body, + boolean truncated, + long bytesRead +) { + + /** + * Defensive copy on construction to prevent callers who retain + * a reference to the input array from mutating response state. + */ + public AdcpHttpResponse { + body = body.clone(); + } + + /** + * Returns a defensive copy of the body bytes. + * Callers may freely mutate the returned array. + */ + @Override + public byte[] body() { + return body.clone(); + } + + /** Returns the body as a UTF-8 string (from the internal copy, no extra clone). */ + public String bodyAsString() { + return new String(body, java.nio.charset.StandardCharsets.UTF_8); + } + + /** Returns the value of a single header, or {@code null} if absent. */ + public @Nullable String header(String name) { + return headers.firstValue(name).orElse(null); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof AdcpHttpResponse that)) return false; + return statusCode == that.statusCode + && truncated == that.truncated + && bytesRead == that.bytesRead + && java.util.Arrays.equals(body, that.body) + && headers.equals(that.headers); + } + + @Override + public int hashCode() { + int h = Integer.hashCode(statusCode); + h = 31 * h + headers.hashCode(); + h = 31 * h + java.util.Arrays.hashCode(body); + h = 31 * h + Boolean.hashCode(truncated); + h = 31 * h + Long.hashCode(bytesRead); + return h; + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/http/DnsPinResolver.java b/adcp/src/main/java/org/adcontextprotocol/adcp/http/DnsPinResolver.java new file mode 100644 index 0000000..55d1edd --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/http/DnsPinResolver.java @@ -0,0 +1,59 @@ +package org.adcontextprotocol.adcp.http; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.UnknownHostException; + +/** + * DNS validator that resolves a hostname once, validates every returned + * address against an {@link SsrfPolicy}, and returns the first validated + * address. + * + *

Resolution uses {@link InetAddress#getAllByName(String)} (the system + * resolver). Callers keep the original URI authority unchanged so TLS SNI + * and hostname verification continue to use the hostname instead of an IP + * literal. This means validation happens at resolve time only and callers + * must accept the remaining TOCTOU window between DNS validation and the + * eventual connect. + */ +public final class DnsPinResolver { + + private DnsPinResolver() {} + + /** + * Resolves {@code host} and validates every returned address against + * the given {@link SsrfPolicy}. Returns the first validated address. + * + * @throws SsrfBlockedException if any resolved address is denied + * @throws UnknownHostException if the host cannot be resolved + */ + public static InetAddress resolveAndPin(String host, SsrfPolicy policy) throws IOException { + InetAddress[] addresses = InetAddress.getAllByName(host); + if (addresses.length == 0) { + throw new UnknownHostException("No addresses resolved for: " + host); + } + + for (InetAddress addr : addresses) { + SsrfDecision decision = policy.evaluate(addr); + if (decision instanceof SsrfDecision.Deny deny) { + throw new SsrfBlockedException(host, deny.reason()); + } + } + + return addresses[0]; + } + + /** + * Validates a literal IP address (no DNS resolution needed) against + * the policy. + * + * @throws SsrfBlockedException if the address is denied + */ + public static void validateAddress(InetAddress address, SsrfPolicy policy) { + SsrfDecision decision = policy.evaluate(address); + if (decision instanceof SsrfDecision.Deny deny) { + throw new SsrfBlockedException( + address.getHostAddress(), deny.reason()); + } + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/http/ProtectedHeaders.java b/adcp/src/main/java/org/adcontextprotocol/adcp/http/ProtectedHeaders.java new file mode 100644 index 0000000..699e1d7 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/http/ProtectedHeaders.java @@ -0,0 +1,26 @@ +package org.adcontextprotocol.adcp.http; + +import java.util.Locale; +import java.util.Set; + +/** + * Headers that must never be overridden by caller-supplied values. + * + *

Shared between {@link AdcpHttpClient} and the MCP transport layer + * to prevent duplication drift. + */ +public final class ProtectedHeaders { + + /** Headers that SDK-managed transports must not allow callers to set. */ + public static final Set NAMES = Set.of( + "host", "user-agent", "content-length", "transfer-encoding", + "connection", "upgrade", + "authorization", "cookie", "proxy-authorization"); + + private ProtectedHeaders() {} + + /** Returns {@code true} if the given header name is protected (case-insensitive). */ + public static boolean isProtected(String name) { + return NAMES.contains(name.toLowerCase(Locale.ROOT)); + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/http/SsrfBlockedException.java b/adcp/src/main/java/org/adcontextprotocol/adcp/http/SsrfBlockedException.java new file mode 100644 index 0000000..497eb0f --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/http/SsrfBlockedException.java @@ -0,0 +1,33 @@ +package org.adcontextprotocol.adcp.http; + +/** + * Thrown when an outbound request is blocked by the {@link SsrfPolicy}. + * + *

The {@link #reason()} describes the blocked range (e.g. "loopback", + * "RFC 1918 private") without echoing the actual address, to avoid + * leaking host structure to callers. + */ +public final class SsrfBlockedException extends RuntimeException { + + @java.io.Serial + private static final long serialVersionUID = 1L; + + private final String host; + private final String reason; + + SsrfBlockedException(String host, String reason) { + super("SSRF blocked: " + reason); + this.host = host; + this.reason = reason; + } + + /** The hostname or IP that was blocked. Package-private to limit exposure. */ + String host() { + return host; + } + + /** Why the address was blocked (range description, not the address). */ + public String reason() { + return reason; + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/http/StrictSsrfPolicy.java b/adcp/src/main/java/org/adcontextprotocol/adcp/http/StrictSsrfPolicy.java index 2529f84..64352c1 100644 --- a/adcp/src/main/java/org/adcontextprotocol/adcp/http/StrictSsrfPolicy.java +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/http/StrictSsrfPolicy.java @@ -53,32 +53,51 @@ public SsrfDecision evaluate(InetAddress address) { return new SsrfDecision.Deny("reserved (240.0.0.0/4)"); } } - if (effective instanceof Inet6Address v6 && isIpv6UniqueLocal(v6)) { - return new SsrfDecision.Deny("IPv6 unique local (fc00::/7)"); + if (effective instanceof Inet6Address v6) { + if (isIpv6UniqueLocal(v6)) { + return new SsrfDecision.Deny("IPv6 unique local (fc00::/7)"); + } + if (is6to4(v6)) { + return new SsrfDecision.Deny("6to4 relay (2002::/16) embedding private IPv4"); + } + if (isTeredo(v6)) { + return new SsrfDecision.Deny("Teredo (2001:0000::/32) embedding private IPv4"); + } + if (isNat64(v6)) { + return new SsrfDecision.Deny("NAT64 well-known (64:ff9b::/96) embedding private IPv4"); + } } return SsrfDecision.ALLOW; } private static InetAddress unmapIpv4Mapped(InetAddress address) { - // ::ffff:0:0/96 — an IPv4 address tunneled inside an IPv6 address. - // The JDK's range methods evaluate the v6 form, not the embedded v4, - // so we unwrap to apply the v4 ranges (RFC 1918 etc.) to the - // effective destination. - // - // Note: Inet6Address.isIPv4CompatibleAddress() checks the legacy - // "::a.b.c.d" form (which also matches ::1), not the IPv4-mapped - // "::ffff:a.b.c.d" form we want. We test the bytes directly. + // Unwrap both IPv4-mapped (::ffff:a.b.c.d) and IPv4-compatible + // (::a.b.c.d) IPv6 addresses so that the embedded IPv4 address + // gets evaluated against the IPv4 block ranges. The compatible + // form is deprecated (RFC 4291 §2.5.5.1) but still parsed by + // JDK's InetAddress, and JDK's range methods (isLoopback, etc.) + // return false for these addresses — making them an SSRF vector. if (!(address instanceof Inet6Address v6)) { return address; } byte[] addr = v6.getAddress(); - // First 80 bits zero, next 16 bits 0xFFFF — the IPv4-mapped form. + // First 80 bits must be zero (common to both forms) for (int i = 0; i < 10; i++) { if (addr[i] != 0) { return address; } } - if ((addr[10] & 0xFF) != 0xFF || (addr[11] & 0xFF) != 0xFF) { + // IPv4-mapped: bytes 10-11 = 0xFF, 0xFF + boolean isMapped = (addr[10] & 0xFF) == 0xFF && (addr[11] & 0xFF) == 0xFF; + // IPv4-compatible: bytes 10-11 = 0x00, 0x00 (and not all-zeros/::1) + boolean isCompat = addr[10] == 0 && addr[11] == 0; + if (!isMapped && !isCompat) { + return address; + } + // Guard: don't unwrap :: (all zeros) or ::1 — those are already + // handled by isAnyLocalAddress() / isLoopbackAddress() + if (isCompat && addr[12] == 0 && addr[13] == 0 + && addr[14] == 0 && (addr[15] == 0 || addr[15] == 1)) { return address; } byte[] v4Bytes = new byte[]{addr[12], addr[13], addr[14], addr[15]}; @@ -118,4 +137,69 @@ private static boolean isIpv6UniqueLocal(Inet6Address v6) { // fc00::/7 — the first byte is 0xFC or 0xFD. return firstByte == 0xFC || firstByte == 0xFD; } + + /** + * 6to4 (2002::/16) — embeds an IPv4 address in bytes 2-5. + * A 6to4 address embedding a private IPv4 (e.g. 2002:7f00:0001:: → 127.0.0.1) + * is an SSRF vector. + */ + private boolean is6to4(Inet6Address v6) { + byte[] b = v6.getAddress(); + if ((b[0] & 0xFF) != 0x20 || (b[1] & 0xFF) != 0x02) { + return false; + } + // Extract embedded IPv4 from bytes 2-5 + byte[] embedded = new byte[]{b[2], b[3], b[4], b[5]}; + try { + InetAddress embeddedV4 = InetAddress.getByAddress(embedded); + return evaluate(embeddedV4) instanceof SsrfDecision.Deny; + } catch (Exception e) { + return true; // fail-closed + } + } + + /** + * Teredo (2001:0000::/32) — embeds an obfuscated IPv4 in the last 4 bytes + * (XOR'd with 0xFF). Block if the decoded IPv4 is private. + */ + private boolean isTeredo(Inet6Address v6) { + byte[] b = v6.getAddress(); + if ((b[0] & 0xFF) != 0x20 || (b[1] & 0xFF) != 0x01 + || b[2] != 0 || b[3] != 0) { + return false; + } + // Teredo client IPv4 is in bytes 12-15, XOR'd with 0xFF + byte[] embedded = new byte[]{ + (byte) (~b[12] & 0xFF), (byte) (~b[13] & 0xFF), + (byte) (~b[14] & 0xFF), (byte) (~b[15] & 0xFF)}; + try { + InetAddress embeddedV4 = InetAddress.getByAddress(embedded); + return evaluate(embeddedV4) instanceof SsrfDecision.Deny; + } catch (Exception e) { + return true; + } + } + + /** + * NAT64 well-known prefix (64:ff9b::/96) — embeds an IPv4 in the + * last 4 bytes. Block if the embedded IPv4 is private. + */ + private boolean isNat64(Inet6Address v6) { + byte[] b = v6.getAddress(); + // 64:ff9b:: → 0x00, 0x64, 0xff, 0x9b, then 8 zero bytes + if (b[0] != 0x00 || (b[1] & 0xFF) != 0x64 + || (b[2] & 0xFF) != 0xFF || (b[3] & 0xFF) != 0x9B) { + return false; + } + for (int i = 4; i < 12; i++) { + if (b[i] != 0) return false; + } + byte[] embedded = new byte[]{b[12], b[13], b[14], b[15]}; + try { + InetAddress embeddedV4 = InetAddress.getByAddress(embedded); + return evaluate(embeddedV4) instanceof SsrfDecision.Deny; + } catch (Exception e) { + return true; + } + } } diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/schema/AdcpObjectMapperFactory.java b/adcp/src/main/java/org/adcontextprotocol/adcp/schema/AdcpObjectMapperFactory.java index 47d6c53..73ac477 100644 --- a/adcp/src/main/java/org/adcontextprotocol/adcp/schema/AdcpObjectMapperFactory.java +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/schema/AdcpObjectMapperFactory.java @@ -24,11 +24,11 @@ public final class AdcpObjectMapperFactory { private AdcpObjectMapperFactory() {} - /** Maximum string length for AdCP payloads (100 MB). */ - private static final int MAX_STRING_LENGTH = 100_000_000; + /** Maximum string length for AdCP payloads (10 MB). */ + private static final int MAX_STRING_LENGTH = 10_000_000; /** Maximum nesting depth for AdCP catalog responses. */ - private static final int MAX_NESTING_DEPTH = 2000; + private static final int MAX_NESTING_DEPTH = 200; /** * Creates a new {@link ObjectMapper} configured for AdCP payloads. diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/transport/CallToolOptions.java b/adcp/src/main/java/org/adcontextprotocol/adcp/transport/CallToolOptions.java new file mode 100644 index 0000000..0c6de00 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/transport/CallToolOptions.java @@ -0,0 +1,64 @@ +package org.adcontextprotocol.adcp.transport; + +import org.jspecify.annotations.Nullable; + +import java.time.Duration; + +/** + * Options for a single {@code callTool()} invocation. + * + *

Implementation status: In v0.1, MCP transport applies a + * fixed 10 MB content limit regardless of {@code maxResponseBytes}. + * The {@code timeout} and {@code maxResponseBytes} fields are accepted + * for forward compatibility but are not yet wired into the + * MCP transport path. They will be enforced when the call-level timeout + * and per-agent body-cap features ship (planned v0.2). + * + * @param timeout per-call timeout (overrides client default) — reserved, not yet enforced + * @param maxResponseBytes per-call body cap (overrides client default) — reserved, not yet enforced + * @param validateResponse whether to validate the response against schema + */ +public record CallToolOptions( + @Nullable Duration timeout, + @Nullable Long maxResponseBytes, + boolean validateResponse +) { + + /** Default options: no timeout override, no body cap override, validation off. */ + public static final CallToolOptions DEFAULT = new CallToolOptions(null, null, false); + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + private @Nullable Duration timeout; + private @Nullable Long maxResponseBytes; + private boolean validateResponse; + + private Builder() {} + + public Builder timeout(Duration timeout) { + this.timeout = timeout; + return this; + } + + public Builder maxResponseBytes(long maxResponseBytes) { + if (maxResponseBytes <= 0) { + throw new IllegalArgumentException( + "maxResponseBytes must be positive: " + maxResponseBytes); + } + this.maxResponseBytes = maxResponseBytes; + return this; + } + + public Builder validateResponse(boolean validateResponse) { + this.validateResponse = validateResponse; + return this; + } + + public CallToolOptions build() { + return new CallToolOptions(timeout, maxResponseBytes, validateResponse); + } + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/transport/ProtocolClient.java b/adcp/src/main/java/org/adcontextprotocol/adcp/transport/ProtocolClient.java new file mode 100644 index 0000000..5076b1b --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/transport/ProtocolClient.java @@ -0,0 +1,270 @@ +package org.adcontextprotocol.adcp.transport; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpSyncClient; +import org.adcontextprotocol.adcp.AdcpVersion; +import org.adcontextprotocol.adcp.AgentConfig; +import org.adcontextprotocol.adcp.Protocol; +import org.adcontextprotocol.adcp.auth.AuthTokenResolver; +import org.adcontextprotocol.adcp.error.FeatureUnsupportedError; +import org.adcontextprotocol.adcp.error.ProtocolError; +import org.adcontextprotocol.adcp.http.SsrfPolicy; +import org.adcontextprotocol.adcp.transport.mcp.McpCaller; +import org.adcontextprotocol.adcp.transport.mcp.McpConnectionManager; +import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.charset.StandardCharsets; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.util.HexFormat; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import javax.crypto.Mac; +import javax.crypto.spec.SecretKeySpec; + +/** + * Dispatches tool calls to the appropriate transport (MCP or A2A). + * + *

This is the central dispatch point — all named tool methods in + * {@code AdcpClient} funnel through here. Mirrors the TS SDK's + * {@code ProtocolClient.callTool()}. + */ +public final class ProtocolClient implements AutoCloseable { + + private static final Logger log = LoggerFactory.getLogger(ProtocolClient.class); + + private final McpConnectionManager connectionManager; + private final McpCaller mcpCaller; + private final SsrfPolicy ssrfPolicy; + private final @Nullable AdcpVersion adcpVersion; + + /** + * Creates a new protocol client. + * + * @param objectMapper Jackson ObjectMapper for serialization + * @param ssrfPolicy SSRF policy for URL validation + * @param adcpVersion protocol version for the version envelope + * @param connectionManager MCP connection manager (shared) + */ + public ProtocolClient(ObjectMapper objectMapper, SsrfPolicy ssrfPolicy, + @Nullable AdcpVersion adcpVersion, + McpConnectionManager connectionManager) { + this.connectionManager = connectionManager; + this.mcpCaller = new McpCaller(objectMapper); + this.ssrfPolicy = ssrfPolicy; + this.adcpVersion = adcpVersion; + } + + /** + * Calls a tool on the given agent. + * + * @param agent the agent configuration + * @param toolName the tool name (e.g. "get_products") + * @param args tool arguments (caller-supplied) + * @param responseType expected response type + * @param options call options (timeout, validation, etc.) + * @param response type + * @return the deserialized response + */ + public T callTool(AgentConfig agent, String toolName, + Map args, Class responseType, + CallToolOptions options) { + + // 1. Check protocol support early so unsupported transports fail fast + if (agent.protocol() == org.adcontextprotocol.adcp.Protocol.A2A) { + throw new FeatureUnsupportedError( + List.of("A2A transport"), + List.of("MCP")); + } + + // 2. Validate agent URL against SSRF policy + validateUrl(agent); + + // 2. Warn if non-default options are passed (not yet enforced in v0.1) + if (!options.equals(CallToolOptions.DEFAULT)) { + log.debug("CallToolOptions fields are not yet enforced by the MCP transport (v0.1)"); + } + + // 3. Resolve auth headers + Map authHeaders = AuthTokenResolver.resolve(agent); + + // 4. Merge headers: filter extraHeaders through ProtectedHeaders first + // (so callers cannot override authorization/cookie/etc.), then add + // SDK-resolved auth headers which are trusted and must not be filtered. + Map allHeaders = new LinkedHashMap<>(); + agent.extraHeaders().forEach((name, value) -> { + if (!org.adcontextprotocol.adcp.http.ProtectedHeaders.isProtected(name)) { + allHeaders.put(name, value); + } else { + log.debug("Dropping protected extraHeader: {}", name); + } + }); + allHeaders.putAll(authHeaders); + + // 5. Build version envelope and merge into args + AdcpVersion version = agent.adcpVersion() != null ? agent.adcpVersion() : adcpVersion; + Map mergedArgs = VersionEnvelope.mergeInto(args, version); + + // 6. Dispatch to transport (A2A already rejected in step 1) + return callViaMcp(agent, toolName, mergedArgs, allHeaders, responseType); + } + + /** + * Convenience: calls a tool with default options. + */ + public T callTool(AgentConfig agent, String toolName, + Map args, Class responseType) { + return callTool(agent, toolName, args, responseType, CallToolOptions.DEFAULT); + } + + @Override + public void close() { + connectionManager.close(); + } + + private T callViaMcp(AgentConfig agent, String toolName, + Map mergedArgs, + Map headers, + Class responseType) { + String cacheHash = computeCacheHash(agent); + McpSyncClient client = connectionManager.getOrConnect( + agent.agentUri(), headers, cacheHash); + + try { + return mcpCaller.callTool(client, toolName, mergedArgs, responseType); + } catch (ProtocolError e) { + if (!isTransportError(e)) { + throw e; + } + // On transport error, evict and retry once + connectionManager.evict(agent.agentUri(), cacheHash); + log.debug("MCP transport error for {}, retrying after evict: {}", + toolName, e.getMessage()); + + ProtocolError original = e; + client = connectionManager.getOrConnect( + agent.agentUri(), headers, cacheHash); + try { + return mcpCaller.callTool(client, toolName, mergedArgs, responseType); + } catch (ProtocolError retry) { + retry.addSuppressed(original); + throw retry; + } + } + } + + private boolean isTransportError(ProtocolError e) { + // Walk the full cause chain — any I/O or timeout failure is transient + for (Throwable t = e.getCause(); t != null; t = t.getCause()) { + if (t instanceof java.io.IOException + || t instanceof java.net.http.HttpTimeoutException) { + return true; + } + } + return false; + } + + private void validateUrl(AgentConfig agent) { + String scheme = agent.agentUri().getScheme(); + if (!"http".equalsIgnoreCase(scheme) && !"https".equalsIgnoreCase(scheme)) { + throw new ProtocolError("mcp", + "Agent URI scheme must be http or https: " + agent.agentUri(), null); + } + String host = agent.agentUri().getHost(); + if (host == null) { + throw new ProtocolError("mcp", + "Agent URI has no host: " + agent.agentUri(), null); + } + // Resolve DNS and validate all addresses against SSRF policy. + // Probes are routed through AdcpHttpClient (which re-validates), + // but the MCP transport's underlying HttpClient still re-resolves + // DNS independently (TOCTOU limitation). This early check blocks + // the common case of misconfigured URIs pointing at private addresses. + try { + java.net.InetAddress[] addresses = java.net.InetAddress.getAllByName(host); + for (java.net.InetAddress addr : addresses) { + org.adcontextprotocol.adcp.http.DnsPinResolver.validateAddress( + addr, ssrfPolicy); + } + } catch (org.adcontextprotocol.adcp.http.SsrfBlockedException e) { + throw new ProtocolError("mcp", + "Agent URI blocked by SSRF policy", e); + } catch (java.net.UnknownHostException e) { + throw new ProtocolError("mcp", + "Cannot resolve agent host", e); + } + } + + /** + * Per-process HMAC key — prevents token hash reversibility in heap dumps. + * Generated once at class-load time; never persisted. + */ + private static final byte[] HMAC_KEY; + static { + HMAC_KEY = new byte[32]; + new SecureRandom().nextBytes(HMAC_KEY); + } + + /** + * Computes a combined hash of credentials + extraHeaders for use as + * a connection cache key. This ensures connections are not shared + * across different auth tokens or different routing headers. + */ + private static String computeCacheHash(AgentConfig agent) { + String tokenHash = computeTokenHash(agent); + if (agent.extraHeaders().isEmpty()) { + return tokenHash; + } + Mac mac = createHmac(); + mac.update(tokenHash.getBytes(StandardCharsets.UTF_8)); + mac.update((byte) '\0'); + agent.extraHeaders().entrySet().stream() + .sorted(Map.Entry.comparingByKey()) + .forEach(e -> { + mac.update(e.getKey().getBytes(StandardCharsets.UTF_8)); + mac.update((byte) '='); + mac.update(e.getValue().getBytes(StandardCharsets.UTF_8)); + mac.update((byte) '\n'); + }); + return HexFormat.of().formatHex(mac.doFinal()); + } + + /** + * Computes an HMAC-SHA256 of the agent's credentials for use as a + * cache key component. The per-process random HMAC key prevents + * reversal of known token formats (e.g. {@code ghp_*}) from heap dumps. + */ + static String computeTokenHash(AgentConfig agent) { + String token = ""; + if (agent.authToken() != null) { + token = agent.authToken(); + } else if (agent.oauthTokens() != null) { + token = agent.oauthTokens().accessToken(); + } else if (agent.basicAuth() != null) { + token = agent.basicAuth().username() + ":" + agent.basicAuth().password(); + } else if (agent.oauthClientCredentials() != null) { + token = "cc:" + agent.oauthClientCredentials().clientId(); + } + if (token.isEmpty()) { + return "anonymous"; + } + Mac mac = createHmac(); + return HexFormat.of().formatHex( + mac.doFinal(token.getBytes(StandardCharsets.UTF_8))); + } + + private static Mac createHmac() { + try { + Mac mac = Mac.getInstance("HmacSHA256"); + mac.init(new SecretKeySpec(HMAC_KEY, "HmacSHA256")); + return mac; + } catch (NoSuchAlgorithmException | InvalidKeyException e) { + throw new AssertionError("HmacSHA256 not available", e); + } + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/transport/VersionEnvelope.java b/adcp/src/main/java/org/adcontextprotocol/adcp/transport/VersionEnvelope.java new file mode 100644 index 0000000..5b5c1a6 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/transport/VersionEnvelope.java @@ -0,0 +1,77 @@ +package org.adcontextprotocol.adcp.transport; + +import org.adcontextprotocol.adcp.AdcpVersion; +import org.jspecify.annotations.Nullable; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Builds the version envelope injected into every tool call's arguments. + * + *

Per the AdCP protocol, every request carries: + *

    + *
  • {@code adcp_major_version} — always present (e.g. {@code 3})
  • + *
  • {@code adcp_version} — present only when a specific minor version + * is pinned (e.g. {@code "3.1"})
  • + *
+ * + *

SDK-set version fields take precedence over caller args. If a caller + * attempts to override {@code adcp_major_version}, a warning is logged and + * the SDK value is used. + */ +public final class VersionEnvelope { + + private static final Logger log = LoggerFactory.getLogger(VersionEnvelope.class); + + private VersionEnvelope() {} + + /** + * Builds a version envelope map for the given protocol version. + * + * @param version the AdCP version (may be {@code null} for default v3) + * @return map with version fields + */ + public static Map build(@Nullable AdcpVersion version) { + AdcpVersion v = version != null ? version : AdcpVersion.V3; + Map envelope = new LinkedHashMap<>(); + envelope.put("adcp_major_version", v.majorVersion()); + if (v.minorVersion() != null) { + envelope.put("adcp_version", v.minorVersion()); + } + return envelope; + } + + /** + * Merges the version envelope into the tool call arguments. + * SDK version fields take precedence — caller overrides are logged + * as warnings and discarded. + * + * @param callerArgs the caller's arguments (may be empty or null) + * @param version the AdCP version + * @return merged arguments with version envelope + */ + public static Map mergeInto( + @Nullable Map callerArgs, + @Nullable AdcpVersion version) { + Map envelope = build(version); + Map merged = new LinkedHashMap<>(); + if (callerArgs != null) { + for (var entry : callerArgs.entrySet()) { + if (envelope.containsKey(entry.getKey())) { + log.warn("Caller attempted to override SDK version field '{}' " + + "(caller={}, SDK={}); SDK value wins", + entry.getKey(), entry.getValue(), + envelope.get(entry.getKey())); + } else { + merged.put(entry.getKey(), entry.getValue()); + } + } + } + merged.putAll(envelope); // SDK wins + return merged; + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/transport/mcp/McpCaller.java b/adcp/src/main/java/org/adcontextprotocol/adcp/transport/mcp/McpCaller.java new file mode 100644 index 0000000..1c3f98f --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/transport/mcp/McpCaller.java @@ -0,0 +1,157 @@ +package org.adcontextprotocol.adcp.transport.mcp; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.spec.McpSchema; +import org.adcontextprotocol.adcp.error.ProtocolError; +import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Map; + +/** + * Calls MCP tools via an established {@link McpSyncClient} connection. + * + *

Wraps the MCP SDK's {@code callTool()} API, extracting structured + * content and deserializing to the target response type. + */ +public final class McpCaller { + + private static final Logger log = LoggerFactory.getLogger(McpCaller.class); + + /** Maximum allowed TextContent length (10 MB, matching ObjectMapper limits). */ + private static final int MAX_CONTENT_LENGTH = 10 * 1024 * 1024; + + private final ObjectMapper objectMapper; + + public McpCaller(ObjectMapper objectMapper) { + // Harden the ObjectMapper against polymorphic deserialization attacks. + // When responseType is Object.class or Map.class, a default-typed + // mapper could instantiate arbitrary classes from incoming JSON + // (gadget-chain attacks). We defensively disable these features. + this.objectMapper = objectMapper.copy(); + this.objectMapper.deactivateDefaultTyping(); + } + + /** + * Calls an MCP tool and deserializes the response. + * + * @param client the connected MCP client + * @param toolName the MCP tool name (e.g. "get_products") + * @param args the merged arguments (including version envelope) + * @param responseType the expected response type + * @param response type + * @return the deserialized response + * @throws ProtocolError if the call fails or the response is unparseable + */ + public T callTool(McpSyncClient client, String toolName, + Map args, Class responseType) { + try { + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest(toolName, args); + McpSchema.CallToolResult result = client.callTool(request); + + return extractResponse(result, responseType); + } catch (ProtocolError e) { + throw e; + } catch (Exception e) { + throw new ProtocolError("mcp", + "MCP callTool failed for " + toolName + ": " + e.getMessage(), e); + } + } + + @SuppressWarnings("unchecked") + private T extractResponse(McpSchema.CallToolResult result, Class responseType) { + // If the tool itself reported an error, surface it before trying + // to deserialize the content as a success payload. + if (Boolean.TRUE.equals(result.isError())) { + String errorText = extractErrorText(result); + throw new ProtocolError("mcp", "MCP tool returned an error: " + errorText, null); + } + + // MCP 2025-06-18: prefer structuredContent over content[] for typed payloads. + Object structured = result.structuredContent(); + if (structured != null) { + try { + JsonNode node = objectMapper.valueToTree(structured); + return objectMapper.treeToValue(node, responseType); + } catch (Exception e) { + log.debug("Failed to parse structuredContent as {}: {}", + responseType.getSimpleName(), e.getMessage()); + // Fall through to content[] path + } + } + + if (result.content() == null || result.content().isEmpty()) { + if (structured != null) { + throw new ProtocolError("mcp", + "Cannot deserialize structuredContent to " + + responseType.getSimpleName(), null); + } + throw new ProtocolError("mcp", "Empty response from MCP callTool", null); + } + + // Fall back to content[] TextContent path + Exception firstParseError = null; + for (McpSchema.Content content : result.content()) { + if (content instanceof McpSchema.TextContent textContent) { + String text = textContent.text(); + if (text == null) { + log.debug("Skipping TextContent with null text"); + continue; + } + if (text.length() > MAX_CONTENT_LENGTH) { + throw new ProtocolError("mcp", + "MCP response content exceeds size limit (" + + text.length() + " > " + + MAX_CONTENT_LENGTH + ")", null); + } + try { + return objectMapper.readValue(text, responseType); + } catch (Exception e) { + if (firstParseError == null) firstParseError = e; + log.debug("Failed to parse TextContent as {}: {}", + responseType.getSimpleName(), e.getMessage()); + } + } + } + + // If no parseable content found, try converting the first content item + McpSchema.Content first = result.content().getFirst(); + try { + JsonNode node = objectMapper.valueToTree(first); + return objectMapper.treeToValue(node, responseType); + } catch (Exception e) { + if (firstParseError != null) e.addSuppressed(firstParseError); + throw new ProtocolError("mcp", + "Cannot deserialize MCP response to " + responseType.getSimpleName(), + e); + } + } + + private static final int MAX_ERROR_LENGTH = 500; + + private String extractErrorText(McpSchema.CallToolResult result) { + if (result.content() != null) { + for (McpSchema.Content content : result.content()) { + if (content instanceof McpSchema.TextContent tc) { + return sanitizeErrorText(tc.text()); + } + } + } + return "(no error detail)"; + } + + private static String sanitizeErrorText(String raw) { + if (raw == null) { + return "(no error detail)"; + } + String truncated = raw.length() > MAX_ERROR_LENGTH + ? raw.substring(0, MAX_ERROR_LENGTH) + "..." + : raw; + // Strip control characters (except tab/newline) to prevent + // injection into downstream systems (logs, LLM context) + return truncated.replaceAll("[\\p{Cc}&&[^\t\n]]", ""); + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/transport/mcp/McpConnectionManager.java b/adcp/src/main/java/org/adcontextprotocol/adcp/transport/mcp/McpConnectionManager.java new file mode 100644 index 0000000..5e6e47b --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/transport/mcp/McpConnectionManager.java @@ -0,0 +1,459 @@ +package org.adcontextprotocol.adcp.transport.mcp; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpError; +import org.adcontextprotocol.adcp.auth.AuthChallengeInfo; +import org.adcontextprotocol.adcp.auth.WwwAuthenticateParser; +import org.adcontextprotocol.adcp.error.AuthenticationRequiredError; +import org.adcontextprotocol.adcp.error.ProtocolError; +import org.adcontextprotocol.adcp.http.AdcpHttpClient; +import org.adcontextprotocol.adcp.http.AdcpHttpResponse; +import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Semaphore; +import java.util.concurrent.locks.ReentrantLock; + +/** + * Manages cached MCP client connections with LRU eviction. + * + *

Cache key: {@code agentUrl::tokenHash}. Max {@value #MAX_CACHE_SIZE} entries. + * Implements StreamableHTTP → SSE fallback per TS SDK behavior. + * + *

Thread-safe: cache reads/writes use {@code cacheLock} (short-held, never + * during I/O). Connection establishment uses a fixed-size striped + * {@link Semaphore} pool so that: (a) only one thread connects per stripe, + * (b) different stripes proceed in parallel, and (c) virtual threads are not + * pinned during blocking network I/O. + * + *

LRU eviction note: An in-use client may be evicted by + * another thread's connection if the cache is full. The evicted client's + * in-flight call will fail with an IOException, which + * {@link org.adcontextprotocol.adcp.transport.ProtocolClient} handles via + * evict-and-retry. This matches the TS SDK's behavior. + */ +public final class McpConnectionManager implements AutoCloseable { + + private static final Logger log = LoggerFactory.getLogger(McpConnectionManager.class); + static final int MAX_CACHE_SIZE = 20; + private static final int STRIPE_COUNT = 32; + + private final LinkedHashMap cache = + new LinkedHashMap<>(16, 0.75f, true); + // Short-held lock for cache reads/writes; never held during I/O + private final ReentrantLock cacheLock = new ReentrantLock(); + // Fixed-size striped semaphore pool for connection establishment. + // Semaphores are virtual-thread-friendly (no carrier pinning) and + // the fixed pool eliminates the cleanup/race issues of per-key locks. + private final Semaphore[] connectStripes; + private final ConcurrentHashMap.KeySetView + knownStreamableKeys = ConcurrentHashMap.newKeySet(); + private final Duration connectTimeout; + private final Duration requestTimeout; + private final AdcpHttpClient adcpHttpClient; + private volatile boolean closed; + + public McpConnectionManager() { + this(Duration.ofSeconds(10)); + } + + public McpConnectionManager(Duration connectTimeout) { + this(connectTimeout, Duration.ofSeconds(30)); + } + + public McpConnectionManager(Duration connectTimeout, Duration requestTimeout) { + this(connectTimeout, requestTimeout, AdcpHttpClient.builder().build()); + } + + public McpConnectionManager(Duration connectTimeout, Duration requestTimeout, + AdcpHttpClient adcpHttpClient) { + this.connectTimeout = connectTimeout; + this.requestTimeout = requestTimeout; + this.adcpHttpClient = Objects.requireNonNull(adcpHttpClient, "adcpHttpClient"); + this.connectStripes = new Semaphore[STRIPE_COUNT]; + for (int i = 0; i < STRIPE_COUNT; i++) { + connectStripes[i] = new Semaphore(1); + } + } + + /** + * Gets or creates a cached MCP client connection. + * + *

On first connect, tries StreamableHTTP first. On non-401 failure, + * falls back to SSE for unknown endpoints. + * On 401, throws {@link AuthenticationRequiredError} immediately. + * + * @param agentUri the agent's base URI + * @param headers auth + extra headers to inject into MCP requests + * @param tokenHash hash of the auth token (for cache keying) + * @return a connected {@link McpSyncClient} + * @throws IllegalStateException if the manager has been closed + */ + public McpSyncClient getOrConnect(URI agentUri, Map headers, + String tokenHash) { + if (closed) { + throw new IllegalStateException("McpConnectionManager is closed"); + } + String cacheKey = agentUri + "::" + tokenHash; + + // Fast path: check cache under short lock (no I/O) + cacheLock.lock(); + try { + McpSyncClient existing = cache.get(cacheKey); + if (existing != null) { + return existing; + } + } finally { + cacheLock.unlock(); + } + + // Slow path: acquire striped semaphore so that only one thread + // connects per stripe. Different stripes proceed in parallel. + // Semaphore.acquire() is virtual-thread-friendly (no carrier pinning). + int stripe = (cacheKey.hashCode() & 0x7FFFFFFF) % STRIPE_COUNT; + Semaphore sem = connectStripes[stripe]; + try { + sem.acquire(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new ProtocolError("mcp", "Interrupted while connecting to " + agentUri, e); + } + try { + // Double-check after acquiring stripe semaphore + cacheLock.lock(); + try { + if (closed) { + throw new IllegalStateException("McpConnectionManager is closed"); + } + McpSyncClient existing = cache.get(cacheKey); + if (existing != null) { + return existing; + } + } finally { + cacheLock.unlock(); + } + + // Network I/O happens here — only blocks threads in the same stripe + McpSyncClient client = connectWithFallback(agentUri, headers, cacheKey); + + cacheLock.lock(); + try { + if (closed) { + closeQuietly(client); + throw new IllegalStateException("McpConnectionManager is closed"); + } + cache.put(cacheKey, client); + evictOldest(); + } finally { + cacheLock.unlock(); + } + return client; + } finally { + sem.release(); + } + } + + /** + * Evicts a specific connection from the cache. + */ + public void evict(URI agentUri, String tokenHash) { + String cacheKey = agentUri + "::" + tokenHash; + cacheLock.lock(); + try { + McpSyncClient evicted = cache.remove(cacheKey); + if (evicted != null) { + knownStreamableKeys.remove(cacheKey); + closeQuietly(evicted); + } + } finally { + cacheLock.unlock(); + } + } + + /** + * Evicts all cached connections for the given agent URI, regardless of + * token hash. Use this when auth credentials rotate so that stale + * connections with the old token don't linger until LRU eviction. + */ + public void invalidateForAgent(URI agentUri) { + String prefix = agentUri + "::"; + cacheLock.lock(); + try { + var it = cache.entrySet().iterator(); + while (it.hasNext()) { + var entry = it.next(); + if (entry.getKey().startsWith(prefix)) { + it.remove(); + knownStreamableKeys.remove(entry.getKey()); + closeQuietly(entry.getValue()); + } + } + } finally { + cacheLock.unlock(); + } + } + + @Override + public void close() { + cacheLock.lock(); + try { + closed = true; + cache.values().forEach(this::closeQuietly); + cache.clear(); + knownStreamableKeys.clear(); + } finally { + cacheLock.unlock(); + } + } + + private void evictOldest() { + while (cache.size() > MAX_CACHE_SIZE) { + var it = cache.entrySet().iterator(); + if (it.hasNext()) { + var entry = it.next(); + it.remove(); + knownStreamableKeys.remove(entry.getKey()); + closeQuietly(entry.getValue()); + } + } + } + + /** + * Probes the agent URI with a POST to determine whether it speaks + * StreamableHTTP (responds with {@code application/json} or + * {@code text/event-stream}) vs legacy SSE-only. Falls back to SSE + * when the probe gets a 4xx/non-JSON response. + * + *

This replaces the previous exception-based fallback which masked + * legitimate 5xx errors and double-charged every cold connect. + */ + private McpSyncClient connectWithFallback(URI agentUri, Map headers, + String cacheKey) { + String url = agentUri.toString(); + Map safe = sanitizeHeaders(headers); + + // Known-good StreamableHTTP endpoints skip the probe + if (knownStreamableKeys.contains(cacheKey)) { + try { + McpSyncClient client = buildAndInit(url, safe, true); + log.debug("Reconnected to {} via StreamableHTTP (cached)", agentUri); + return client; + } catch (Exception e) { + if (isAuthError(e)) { + throw probeAndBuildAuthError(agentUri, e); + } + // Lost contact — fall through to probe + knownStreamableKeys.remove(cacheKey); + log.debug("Cached StreamableHTTP failed for {}, re-probing", agentUri); + } + } + + // Probe: POST with MCP ping to detect transport type without + // accidentally completing an MCP initialize handshake. + boolean useStreamable = probeSupportsStreamableHttp(agentUri, safe); + + try { + McpSyncClient client = buildAndInit(url, safe, useStreamable); + if (useStreamable) { + knownStreamableKeys.add(cacheKey); + } + log.debug("Connected to {} via {}", agentUri, + useStreamable ? "StreamableHTTP" : "SSE"); + return client; + } catch (Exception e) { + if (isAuthError(e)) { + throw probeAndBuildAuthError(agentUri, e); + } + // If probe said StreamableHTTP but init failed, try SSE as last resort + if (useStreamable) { + log.debug("StreamableHTTP init failed despite probe, trying SSE for {}", + agentUri); + try { + McpSyncClient client = buildAndInit(url, safe, false); + log.debug("Connected to {} via SSE (fallback)", agentUri); + return client; + } catch (Exception e2) { + if (isAuthError(e2)) { + throw probeAndBuildAuthError(agentUri, e2); + } + e2.addSuppressed(e); + throw new ProtocolError("mcp", + "Failed to connect to " + agentUri + + " via StreamableHTTP and SSE", + e2); + } + } + throw new ProtocolError("mcp", + "Failed to connect to " + agentUri + " via SSE", e); + } + } + + /** + * Sends a POST probe to the agent URI to detect StreamableHTTP support. + * StreamableHTTP endpoints respond to POST with {@code application/json} + * or {@code text/event-stream} content-type. Legacy SSE endpoints + * typically return 404/405 on POST to the root. + * + * @return true if the endpoint appears to support StreamableHTTP + */ + private boolean probeSupportsStreamableHttp(URI agentUri, Map headers) { + String pingPayload = "{\"jsonrpc\":\"2.0\",\"method\":\"ping\"," + + "\"id\":\"probe\",\"params\":{}}"; + Map probeHeaders = new LinkedHashMap<>(headers); + probeHeaders.put("Content-Type", "application/json"); + probeHeaders.put("Accept", "application/json, text/event-stream"); + try { + AdcpHttpResponse resp = adcpHttpClient.post( + agentUri, + probeHeaders, + pingPayload.getBytes(StandardCharsets.UTF_8)); + String ct = resp.headers().firstValue("Content-Type").orElse(""); + if (resp.statusCode() >= 200 && resp.statusCode() < 300) { + return ct.contains("application/json") + || ct.contains("text/event-stream"); + } + return false; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + log.debug("StreamableHTTP probe interrupted for {}: {}", agentUri, e.getMessage()); + return true; + } catch (IOException e) { + log.debug("StreamableHTTP probe failed for {}: {}", agentUri, e.getMessage()); + return true; + } + } + + private McpSyncClient buildAndInit(String url, Map headers, + boolean useStreamable) { + var reqBuilder = java.net.http.HttpRequest.newBuilder().timeout(requestTimeout); + McpClientTransport transport = useStreamable + ? HttpClientStreamableHttpTransport.builder(url) + .connectTimeout(connectTimeout) + .requestBuilder(reqBuilder) + .clientBuilder(adcpHttpClient.newMcpClientBuilder()) + .httpRequestCustomizer((rb, method, uri, body, ctx) -> + headers.forEach(rb::header)) + .build() + : HttpClientSseClientTransport.builder(url) + .connectTimeout(connectTimeout) + .requestBuilder(reqBuilder) + .clientBuilder(adcpHttpClient.newMcpClientBuilder()) + .httpRequestCustomizer((rb, method, uri, body, ctx) -> + headers.forEach(rb::header)) + .build(); + McpSyncClient client = McpClient.sync(transport).build(); + try { + client.initialize(); + return client; + } catch (Exception e) { + closeQuietly(client); + throw e; + } + } + + private static Map sanitizeHeaders(Map headers) { + Map sanitized = new LinkedHashMap<>(); + for (var entry : headers.entrySet()) { + String name = entry.getKey(); + String value = entry.getValue(); + if (org.adcontextprotocol.adcp.http.ProtectedHeaders.isProtected(name)) { + log.debug("Skipping protected MCP header: {}", name); + continue; + } + if (hasCrlf(name) || hasCrlf(value)) { + log.warn("Rejecting MCP header with CR/LF characters: {}", name); + continue; + } + sanitized.put(name, value); + } + return sanitized; + } + + private static boolean hasCrlf(String s) { + return s.indexOf('\r') >= 0 || s.indexOf('\n') >= 0; + } + + // NOTE: MCP SDK 1.1.2 does not expose HTTP response headers on errors. + // We work around this by sending a HEAD probe to the agent URI to + // retrieve the WWW-Authenticate challenge. When the MCP SDK adds + // response header access, this probe can be replaced with direct + // header inspection. + + /** + * Probes the agent URI with a HEAD request to retrieve WWW-Authenticate. + * If the probe fails (e.g. network error, non-401 response), returns + * an AuthenticationRequiredError with challenge=null. + */ + private AuthenticationRequiredError probeAndBuildAuthError(URI agentUri, Exception cause) { + AuthChallengeInfo challenge = null; + try { + AdcpHttpResponse resp = adcpHttpClient.send("HEAD", agentUri, Map.of(), null); + challenge = parseAuthChallenge(resp); + if (challenge == null && resp.statusCode() == 405) { + challenge = parseAuthChallenge( + adcpHttpClient.send("OPTIONS", agentUri, Map.of(), null)); + } + } catch (InterruptedException probeEx) { + Thread.currentThread().interrupt(); + log.debug("Auth challenge probe interrupted for {}: {}", + agentUri, probeEx.getMessage()); + } catch (IOException probeEx) { + log.debug("Auth challenge probe failed for {}: {}", + agentUri, probeEx.getMessage()); + } + return new AuthenticationRequiredError(agentUri, challenge, null, cause); + } + + private static @Nullable AuthChallengeInfo parseAuthChallenge(AdcpHttpResponse response) { + if (response.statusCode() != 401) { + return null; + } + String wwwAuth = response.headers().firstValue("WWW-Authenticate").orElse(null); + return WwwAuthenticateParser.parse(wwwAuth); + } + + private boolean isAuthError(Exception e) { + for (Throwable t = e; t != null; t = t.getCause()) { + String msg = t.getMessage(); + if (msg == null) continue; + // Check MCP SDK's error type first + if (t instanceof McpError) { + if (isAuthMessage(msg)) return true; + } + if (isAuthMessage(msg)) return true; + } + return false; + } + + /** Word-bounded 401 matching to avoid false positives like "401234". */ + private static final java.util.regex.Pattern AUTH_401_PATTERN = + java.util.regex.Pattern.compile("\\b401\\b"); + + private static boolean isAuthMessage(String msg) { + return AUTH_401_PATTERN.matcher(msg).find() + || msg.contains("Unauthorized"); + } + + private void closeQuietly(McpSyncClient client) { + try { + if (client != null) { + client.close(); + } + } catch (Exception e) { + log.debug("Error closing MCP client: {}", e.getMessage()); + } + } +} diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/transport/mcp/package-info.java b/adcp/src/main/java/org/adcontextprotocol/adcp/transport/mcp/package-info.java new file mode 100644 index 0000000..6bf5734 --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/transport/mcp/package-info.java @@ -0,0 +1,2 @@ +@org.jspecify.annotations.NullMarked +package org.adcontextprotocol.adcp.transport.mcp; diff --git a/adcp/src/main/java/org/adcontextprotocol/adcp/transport/package-info.java b/adcp/src/main/java/org/adcontextprotocol/adcp/transport/package-info.java new file mode 100644 index 0000000..307e97e --- /dev/null +++ b/adcp/src/main/java/org/adcontextprotocol/adcp/transport/package-info.java @@ -0,0 +1,2 @@ +@org.jspecify.annotations.NullMarked +package org.adcontextprotocol.adcp.transport; diff --git a/adcp/src/test/java/org/adcontextprotocol/adcp/AdcpClientTest.java b/adcp/src/test/java/org/adcontextprotocol/adcp/AdcpClientTest.java new file mode 100644 index 0000000..126ef9c --- /dev/null +++ b/adcp/src/test/java/org/adcontextprotocol/adcp/AdcpClientTest.java @@ -0,0 +1,134 @@ +package org.adcontextprotocol.adcp; + +import org.adcontextprotocol.adcp.error.ConfigurationError; +import org.adcontextprotocol.adcp.http.SsrfPolicy; +import org.junit.jupiter.api.Test; + +import java.net.URI; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Tests for {@link AdcpClient} builder and lifecycle. + */ +class AdcpClientTest { + + private static final URI AGENT_URI = URI.create("https://agent.example.com"); + + @Test + void builder_creates_client() { + try (AdcpClient client = AdcpClient.builder() + .agent(AgentConfig.mcp("test", AGENT_URI)) + .build()) { + assertNotNull(client); + assertEquals("test", client.agent().id()); + assertEquals(AGENT_URI, client.agent().agentUri()); + } + } + + @Test + void builder_rejects_missing_agent() { + assertThrows(ConfigurationError.class, () -> + AdcpClient.builder().build()); + } + + @Test + void builder_with_version() { + try (AdcpClient client = AdcpClient.builder() + .agent(AgentConfig.mcp("test", AGENT_URI)) + .adcpVersion(AdcpVersion.V3_1) + .build()) { + assertEquals(AdcpVersion.V3_1, client.adcpVersion()); + } + } + + @Test + void builder_with_permissive_ssrf() { + // Should not throw — permissive policy allows localhost + try (AdcpClient client = AdcpClient.builder() + .agent(AgentConfig.mcp("test", + URI.create("http://localhost:8080"))) + .ssrfPolicy(SsrfPolicy.permissive()) + .build()) { + assertNotNull(client); + } + } + + @Test + void client_is_autocloseable() { + AdcpClient client = AdcpClient.builder() + .agent(AgentConfig.mcp("test", AGENT_URI)) + .build(); + assertDoesNotThrow(client::close); + } + + @Test + void close_is_idempotent() { + AdcpClient client = AdcpClient.builder() + .agent(AgentConfig.mcp("test", AGENT_URI)) + .build(); + client.close(); + assertDoesNotThrow(client::close); + } + + @Test + void a2a_protocol_rejected_at_call_time() { + AgentConfig a2aAgent = AgentConfig.builder() + .id("a2a") + .agentUri(AGENT_URI) + .protocol(Protocol.A2A) + .build(); + // A2A rejection happens at callTool dispatch (ProtocolClient) + try (AdcpClient client = AdcpClient.builder() + .agent(a2aAgent) + .ssrfPolicy(SsrfPolicy.permissive()) + .build()) { + var ex = assertThrows(org.adcontextprotocol.adcp.error.FeatureUnsupportedError.class, + () -> client.callTool("get_products", + java.util.Map.of(), java.util.Map.class)); + assertTrue(ex.getMessage().contains("A2A")); + } + } + + @Test + void callTool_accepts_null_args_without_npe() { + // Null args should be treated as empty map, not throw NPE. + // The call will fail at transport (no server), but the null-guard + // in callTool must normalise to Map.of() before that point. + AgentConfig a2aAgent = AgentConfig.builder() + .id("a2a") + .agentUri(AGENT_URI) + .protocol(Protocol.A2A) + .build(); + try (AdcpClient client = AdcpClient.builder() + .agent(a2aAgent) + .ssrfPolicy(SsrfPolicy.permissive()) + .build()) { + // A2A rejection fires before any null-arg handling, proving + // the call doesn't NPE on null args. + assertThrows(org.adcontextprotocol.adcp.error.FeatureUnsupportedError.class, + () -> client.callTool("get_products", null, java.util.Map.class)); + } + } + + @Test + void builder_accepts_string_version() { + try (AdcpClient client = AdcpClient.builder() + .agent(AgentConfig.mcp("test", AGENT_URI)) + .adcpVersion("3.0") + .build()) { + assertNotNull(client.adcpVersion()); + assertEquals(3, client.adcpVersion().majorVersion()); + assertEquals("3.0", client.adcpVersion().minorVersion()); + } + } + + @Test + void builder_rejects_cross_major_version() { + assertThrows(org.adcontextprotocol.adcp.error.ConfigurationError.class, + () -> AdcpClient.builder() + .agent(AgentConfig.mcp("test", AGENT_URI)) + .adcpVersion(new AdcpVersion(2, null)) + .build()); + } +} diff --git a/adcp/src/test/java/org/adcontextprotocol/adcp/AdcpVersionTest.java b/adcp/src/test/java/org/adcontextprotocol/adcp/AdcpVersionTest.java new file mode 100644 index 0000000..51ec487 --- /dev/null +++ b/adcp/src/test/java/org/adcontextprotocol/adcp/AdcpVersionTest.java @@ -0,0 +1,123 @@ +package org.adcontextprotocol.adcp; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Tests for {@link AdcpVersion}. + */ +class AdcpVersionTest { + + @Test + void v3_constants() { + assertEquals(3, AdcpVersion.V3.majorVersion()); + assertNull(AdcpVersion.V3.minorVersion()); + } + + @Test + void v3_1_constants() { + assertEquals(3, AdcpVersion.V3_1.majorVersion()); + assertEquals("3.1", AdcpVersion.V3_1.minorVersion()); + } + + @Test + void rejects_zero_major_version() { + assertThrows(IllegalArgumentException.class, + () -> new AdcpVersion(0, null)); + } + + @Test + void rejects_negative_major_version() { + assertThrows(IllegalArgumentException.class, + () -> new AdcpVersion(-1, null)); + } + + @Test + void custom_version() { + var v = new AdcpVersion(4, "4.2"); + assertEquals(4, v.majorVersion()); + assertEquals("4.2", v.minorVersion()); + } + + @Test + void rejects_mismatched_minor_version() { + assertThrows(IllegalArgumentException.class, + () -> new AdcpVersion(3, "4.1"), + "minorVersion must start with majorVersion"); + } + + @Test + void allows_null_minor_version() { + var v = new AdcpVersion(5, null); + assertEquals(5, v.majorVersion()); + assertNull(v.minorVersion()); + } + + @Test + void rejects_minor_version_with_invalid_characters() { + assertThrows(IllegalArgumentException.class, + () -> new AdcpVersion(3, "3.\nFake-Log-Entry"), + "minorVersion must be a version string"); + } + + @Test + void rejects_minor_version_too_long() { + assertThrows(IllegalArgumentException.class, + () -> new AdcpVersion(3, "3.1234567890123456789"), + "minorVersion too long"); + } + + @Test + void accepts_three_part_minor_version() { + var v = new AdcpVersion(3, "3.1.2"); + assertEquals("3.1.2", v.minorVersion()); + } + + // -- AdcpVersion.of(String) -- + + @Test + void of_parses_release_precision_version() { + AdcpVersion v = AdcpVersion.of("3.0"); + assertEquals(3, v.majorVersion()); + assertEquals("3.0", v.minorVersion()); + } + + @Test + void of_parses_minor_version() { + AdcpVersion v = AdcpVersion.of("3.1"); + assertEquals(3, v.majorVersion()); + assertEquals("3.1", v.minorVersion()); + } + + @Test + void of_rejects_major_only_string() { + assertThrows(IllegalArgumentException.class, () -> AdcpVersion.of("3")); + } + + @Test + void of_rejects_non_numeric() { + assertThrows(IllegalArgumentException.class, () -> AdcpVersion.of("abc.def")); + } + + @Test + void of_rejects_null() { + assertThrows(NullPointerException.class, () -> AdcpVersion.of(null)); + } + + // -- AdcpSdkVersion constants (build-time generated) -- + + @Test + void sdk_major_version_is_positive() { + assertTrue(AdcpSdkVersion.SDK_MAJOR_VERSION > 0, + "SDK_MAJOR_VERSION must be a positive integer"); + } + + @Test + void sdk_release_version_matches_major() { + String release = AdcpSdkVersion.SDK_RELEASE_VERSION; + assertTrue(release.startsWith(AdcpSdkVersion.SDK_MAJOR_VERSION + "."), + "SDK_RELEASE_VERSION must start with SDK_MAJOR_VERSION: " + release); + } +} + diff --git a/adcp/src/test/java/org/adcontextprotocol/adcp/AgentConfigTest.java b/adcp/src/test/java/org/adcontextprotocol/adcp/AgentConfigTest.java new file mode 100644 index 0000000..10c6259 --- /dev/null +++ b/adcp/src/test/java/org/adcontextprotocol/adcp/AgentConfigTest.java @@ -0,0 +1,225 @@ +package org.adcontextprotocol.adcp; + +import org.adcontextprotocol.adcp.auth.BasicCredentials; +import org.adcontextprotocol.adcp.auth.OAuthClientCredentials; +import org.adcontextprotocol.adcp.auth.OAuthTokens; +import org.adcontextprotocol.adcp.error.ConfigurationError; +import org.junit.jupiter.api.Test; + +import java.net.URI; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Tests for {@link AgentConfig}. + */ +class AgentConfigTest { + + private static final URI AGENT_URI = URI.create("https://agent.example.com"); + + @Test + void builder_creates_minimal_config() { + AgentConfig config = AgentConfig.builder() + .id("test-agent") + .agentUri(AGENT_URI) + .build(); + + assertEquals("test-agent", config.id()); + assertEquals(AGENT_URI, config.agentUri()); + assertEquals(Protocol.MCP, config.protocol()); + assertNull(config.authToken()); + assertNull(config.basicAuth()); + assertTrue(config.extraHeaders().isEmpty()); + } + + @Test + void static_factory_mcp_no_auth() { + AgentConfig config = AgentConfig.mcp("a", AGENT_URI); + + assertEquals("a", config.id()); + assertEquals(Protocol.MCP, config.protocol()); + assertNull(config.authToken()); + } + + @Test + void static_factory_mcp_with_token() { + AgentConfig config = AgentConfig.mcp("a", AGENT_URI, "my-token"); + + assertEquals("my-token", config.authToken()); + } + + @Test + void builder_with_bearer_token() { + AgentConfig config = AgentConfig.builder() + .id("agent") + .agentUri(AGENT_URI) + .authToken("test-bearer-token") + .build(); + + assertEquals("test-bearer-token", config.authToken()); + assertNull(config.basicAuth()); + } + + @Test + void builder_with_basic_auth() { + AgentConfig config = AgentConfig.builder() + .id("agent") + .agentUri(AGENT_URI) + .basicAuth(new BasicCredentials("user", "pass")) + .build(); + + assertNotNull(config.basicAuth()); + assertEquals("user", config.basicAuth().username()); + } + + @Test + void builder_rejects_multiple_auth() { + assertThrows(ConfigurationError.class, () -> + AgentConfig.builder() + .id("agent") + .agentUri(AGENT_URI) + .authToken("tok") + .basicAuth(new BasicCredentials("u", "p")) + .build()); + } + + @Test + void builder_rejects_missing_id() { + assertThrows(ConfigurationError.class, () -> + AgentConfig.builder() + .agentUri(AGENT_URI) + .build()); + } + + @Test + void builder_rejects_missing_agent_uri() { + assertThrows(ConfigurationError.class, () -> + AgentConfig.builder() + .id("agent") + .build()); + } + + @Test + void extra_headers_are_immutable() { + var headers = new java.util.HashMap(); + headers.put("X-Custom", "value"); + + AgentConfig config = AgentConfig.builder() + .id("agent") + .agentUri(AGENT_URI) + .extraHeaders(headers) + .build(); + + // Modifying the original map doesn't affect the config + headers.put("X-New", "val"); + assertFalse(config.extraHeaders().containsKey("X-New")); + + // The returned map is also immutable + assertThrows(UnsupportedOperationException.class, + () -> config.extraHeaders().put("X-Fail", "val")); + } + + @Test + void builder_with_a2a_protocol() { + AgentConfig config = AgentConfig.builder() + .id("a2a-agent") + .agentUri(AGENT_URI) + .protocol(Protocol.A2A) + .build(); + + assertEquals(Protocol.A2A, config.protocol()); + } + + @Test + void builder_with_oauth_client_credentials() { + var oauthCC = new OAuthClientCredentials( + "client-id", "client-secret", + "https://auth.example.com/token", "read write"); + + AgentConfig config = AgentConfig.builder() + .id("agent") + .agentUri(AGENT_URI) + .oauthClientCredentials(oauthCC) + .build(); + + assertNotNull(config.oauthClientCredentials()); + assertEquals("client-id", config.oauthClientCredentials().clientId()); + } + + @Test + void builder_with_adcp_version() { + AgentConfig config = AgentConfig.builder() + .id("agent") + .agentUri(AGENT_URI) + .adcpVersion(AdcpVersion.V3_1) + .build(); + + assertNotNull(config.adcpVersion()); + assertEquals(3, config.adcpVersion().majorVersion()); + assertEquals("3.1", config.adcpVersion().minorVersion()); + } + + @Test + void toString_redacts_authToken_and_webhookSecret() { + AgentConfig config = AgentConfig.builder() + .id("agent") + .agentUri(AGENT_URI) + .authToken("super-secret-token") + .webhookSecret("hmac-secret-key") + .build(); + + String str = config.toString(); + assertFalse(str.contains("super-secret-token"), + "toString() must not contain authToken value"); + assertFalse(str.contains("hmac-secret-key"), + "toString() must not contain webhookSecret value"); + assertTrue(str.contains(""), + "toString() should show for secrets"); + assertTrue(str.contains("agent"), + "toString() should still show the agent id"); + } + + @Test + void authToken_rejects_crlf() { + assertThrows(ConfigurationError.class, () -> + AgentConfig.mcp("a", AGENT_URI, "token\r\nX-Injected: bad")); + assertThrows(ConfigurationError.class, () -> + AgentConfig.mcp("a", AGENT_URI, "token\ninjection")); + } + + @Test + void toString_redacts_extraHeaders_values() { + AgentConfig config = AgentConfig.builder() + .id("agent") + .agentUri(AGENT_URI) + .extraHeaders(Map.of("X-Api-Key", "secret-key-value")) + .build(); + + String str = config.toString(); + assertFalse(str.contains("secret-key-value"), + "toString() must not contain extra header values"); + assertTrue(str.contains("<1 headers>"), + "toString() should show header count"); + } + + @Test + void rejects_extraHeaders_with_crlf_in_key() { + assertThrows(org.adcontextprotocol.adcp.error.ConfigurationError.class, + () -> AgentConfig.builder() + .id("a1") + .agentUri(AGENT_URI) + .extraHeaders(Map.of("X-Bad\rKey", "value")) + .build()); + } + + @Test + void rejects_extraHeaders_with_crlf_in_value() { + assertThrows(org.adcontextprotocol.adcp.error.ConfigurationError.class, + () -> AgentConfig.builder() + .id("a1") + .agentUri(AGENT_URI) + .extraHeaders(Map.of("X-Key", "bad\nvalue")) + .build()); + } +} diff --git a/adcp/src/test/java/org/adcontextprotocol/adcp/auth/AuthTokenResolverTest.java b/adcp/src/test/java/org/adcontextprotocol/adcp/auth/AuthTokenResolverTest.java new file mode 100644 index 0000000..d109acc --- /dev/null +++ b/adcp/src/test/java/org/adcontextprotocol/adcp/auth/AuthTokenResolverTest.java @@ -0,0 +1,84 @@ +package org.adcontextprotocol.adcp.auth; + +import org.adcontextprotocol.adcp.AgentConfig; +import org.adcontextprotocol.adcp.Protocol; +import org.junit.jupiter.api.Test; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.Base64; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Tests for {@link AuthTokenResolver}. + */ +class AuthTokenResolverTest { + + private static final URI AGENT_URI = URI.create("https://agent.example.com"); + + @Test + void resolve_bearer_token() { + AgentConfig config = AgentConfig.builder() + .id("a") + .agentUri(AGENT_URI) + .authToken("my-token") + .build(); + + Map headers = AuthTokenResolver.resolve(config); + + assertEquals("Bearer my-token", headers.get("Authorization")); + assertEquals("my-token", headers.get("x-adcp-auth")); + } + + @Test + void resolve_basic_auth() { + AgentConfig config = AgentConfig.builder() + .id("a") + .agentUri(AGENT_URI) + .basicAuth(new BasicCredentials("user", "pass")) + .build(); + + Map headers = AuthTokenResolver.resolve(config); + + String expected = "Basic " + Base64.getEncoder().encodeToString( + "user:pass".getBytes(StandardCharsets.UTF_8)); + assertEquals(expected, headers.get("Authorization")); + assertFalse(headers.containsKey("x-adcp-auth")); + } + + @Test + void resolve_oauth_tokens() { + AgentConfig config = AgentConfig.builder() + .id("a") + .agentUri(AGENT_URI) + .oauthTokens(OAuthTokens.bearer("access-tok-123")) + .build(); + + Map headers = AuthTokenResolver.resolve(config); + + assertEquals("Bearer access-tok-123", headers.get("Authorization")); + assertFalse(headers.containsKey("x-adcp-auth")); + } + + @Test + void resolve_no_auth_returns_empty() { + AgentConfig config = AgentConfig.mcp("a", AGENT_URI); + + Map headers = AuthTokenResolver.resolve(config); + + assertTrue(headers.isEmpty()); + } + + @Test + void resolved_headers_are_immutable() { + AgentConfig config = AgentConfig.mcp("a", AGENT_URI, "tok"); + + Map headers = AuthTokenResolver.resolve(config); + + assertThrows(UnsupportedOperationException.class, + () -> headers.put("X-Evil", "val")); + } +} diff --git a/adcp/src/test/java/org/adcontextprotocol/adcp/auth/CredentialsTest.java b/adcp/src/test/java/org/adcontextprotocol/adcp/auth/CredentialsTest.java new file mode 100644 index 0000000..1159f6b --- /dev/null +++ b/adcp/src/test/java/org/adcontextprotocol/adcp/auth/CredentialsTest.java @@ -0,0 +1,131 @@ +package org.adcontextprotocol.adcp.auth; + +import org.junit.jupiter.api.Test; + +import java.time.Instant; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Tests for auth credential records. + */ +class CredentialsTest { + + @Test + void basicCredentials_validates() { + var creds = new BasicCredentials("user", "pass"); + assertEquals("user", creds.username()); + assertEquals("pass", creds.password()); + } + + @Test + void basicCredentials_rejects_blank_username() { + assertThrows(IllegalArgumentException.class, + () -> new BasicCredentials("", "pass")); + assertThrows(IllegalArgumentException.class, + () -> new BasicCredentials(" ", "pass")); + } + + @Test + void basicCredentials_allows_blank_password() { + // Blank passwords are valid — many platforms use username=token, password="" + var creds = new BasicCredentials("user", ""); + assertEquals("", creds.password()); + } + + @Test + void basicCredentials_rejects_colon_in_username() { + assertThrows(IllegalArgumentException.class, + () -> new BasicCredentials("us:er", "pass")); + } + + @Test + void basicCredentials_toString_redacts_password() { + var creds = new BasicCredentials("user", "secret"); + String str = creds.toString(); + assertTrue(str.contains("user")); + assertFalse(str.contains("secret")); + assertTrue(str.contains("")); + } + + @Test + void basicCredentials_rejects_null() { + assertThrows(NullPointerException.class, + () -> new BasicCredentials(null, "pass")); + assertThrows(NullPointerException.class, + () -> new BasicCredentials("user", null)); + } + + @Test + void oauthClientCredentials_validates() { + var cc = new OAuthClientCredentials( + "id", "secret", "https://auth.example.com/token", "read"); + assertEquals("id", cc.clientId()); + assertEquals("read", cc.scope()); + } + + @Test + void oauthClientCredentials_rejects_blank() { + assertThrows(IllegalArgumentException.class, + () -> new OAuthClientCredentials("", "s", "t", null)); + assertThrows(IllegalArgumentException.class, + () -> new OAuthClientCredentials("id", "", "t", null)); + } + + @Test + void oauthClientCredentials_toString_redacts_secret() { + var cc = new OAuthClientCredentials( + "my-id", "super-secret", "https://auth.example.com/token", "read"); + String str = cc.toString(); + assertTrue(str.contains("my-id")); + assertFalse(str.contains("super-secret")); + assertTrue(str.contains("")); + } + + @Test + void oauthTokens_bearer_factory() { + var tokens = OAuthTokens.bearer("access-123"); + assertEquals("access-123", tokens.accessToken()); + assertEquals("Bearer", tokens.tokenType()); + assertNull(tokens.refreshToken()); + assertFalse(tokens.isExpired()); + } + + @Test + void oauthTokens_expired() { + var tokens = OAuthTokens.bearer( + "access", "refresh", Instant.now().minusSeconds(60)); + assertTrue(tokens.isExpired()); + } + + @Test + void oauthTokens_not_expired() { + var tokens = OAuthTokens.bearer( + "access", "refresh", Instant.now().plusSeconds(300)); + assertFalse(tokens.isExpired()); + } + + @Test + void oauthTokens_rejects_blank_access_token() { + assertThrows(IllegalArgumentException.class, + () -> OAuthTokens.bearer("")); + } + + @Test + void oauthTokens_toString_redacts_tokens() { + var tokens = OAuthTokens.bearer("access-secret", "refresh-secret", + Instant.now().plusSeconds(300)); + String str = tokens.toString(); + assertFalse(str.contains("access-secret")); + assertFalse(str.contains("refresh-secret")); + assertTrue(str.contains("")); + } + + @Test + void oauth_access_token_rejects_crlf() { + assertThrows(IllegalArgumentException.class, + () -> OAuthTokens.bearer("token\r\nX-Injected: bad")); + assertThrows(IllegalArgumentException.class, + () -> OAuthTokens.bearer("token\ninjection")); + } +} diff --git a/adcp/src/test/java/org/adcontextprotocol/adcp/auth/WwwAuthenticateParserTest.java b/adcp/src/test/java/org/adcontextprotocol/adcp/auth/WwwAuthenticateParserTest.java new file mode 100644 index 0000000..584c0ea --- /dev/null +++ b/adcp/src/test/java/org/adcontextprotocol/adcp/auth/WwwAuthenticateParserTest.java @@ -0,0 +1,99 @@ +package org.adcontextprotocol.adcp.auth; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Tests for {@link WwwAuthenticateParser}. + */ +class WwwAuthenticateParserTest { + + @Test + void parses_bearer_with_realm() { + AuthChallengeInfo info = WwwAuthenticateParser.parse( + "Bearer realm=\"example\""); + + assertNotNull(info); + assertEquals("bearer", info.scheme()); + assertEquals("example", info.realm()); + assertNull(info.scope()); + assertNull(info.error()); + } + + @Test + void parses_bearer_with_error() { + AuthChallengeInfo info = WwwAuthenticateParser.parse( + "Bearer realm=\"api\", error=\"invalid_token\", " + + "error_description=\"Token expired\""); + + assertNotNull(info); + assertEquals("bearer", info.scheme()); + assertEquals("api", info.realm()); + assertEquals("invalid_token", info.error()); + assertEquals("Token expired", info.errorDescription()); + } + + @Test + void parses_bearer_with_scope() { + AuthChallengeInfo info = WwwAuthenticateParser.parse( + "Bearer scope=\"read write\""); + + assertNotNull(info); + assertEquals("bearer", info.scheme()); + assertEquals("read write", info.scope()); + } + + @Test + void parses_basic_with_realm() { + AuthChallengeInfo info = WwwAuthenticateParser.parse( + "Basic realm=\"Agent Admin\""); + + assertNotNull(info); + assertEquals("basic", info.scheme()); + assertEquals("Agent Admin", info.realm()); + } + + @Test + void parses_scheme_only() { + AuthChallengeInfo info = WwwAuthenticateParser.parse("Bearer"); + + assertNotNull(info); + assertEquals("bearer", info.scheme()); + assertNull(info.realm()); + } + + @Test + void scheme_is_lowercased() { + AuthChallengeInfo info = WwwAuthenticateParser.parse("BEARER realm=\"x\""); + + assertNotNull(info); + assertEquals("bearer", info.scheme()); + } + + @Test + void returns_null_for_blank() { + assertNull(WwwAuthenticateParser.parse(null)); + assertNull(WwwAuthenticateParser.parse("")); + assertNull(WwwAuthenticateParser.parse(" ")); + } + + @Test + void parses_unquoted_values() { + AuthChallengeInfo info = WwwAuthenticateParser.parse( + "Bearer realm=example, error=invalid_token"); + + assertNotNull(info); + assertEquals("example", info.realm()); + assertEquals("invalid_token", info.error()); + } + + @Test + void authChallengeInfo_lowercases_scheme() { + AuthChallengeInfo info = new AuthChallengeInfo("Bearer", null, null, null, null); + assertEquals("bearer", info.scheme(), + "Scheme should be lowercased in the constructor"); + } +} diff --git a/adcp/src/test/java/org/adcontextprotocol/adcp/error/AdcpErrorTest.java b/adcp/src/test/java/org/adcontextprotocol/adcp/error/AdcpErrorTest.java new file mode 100644 index 0000000..f24a7d5 --- /dev/null +++ b/adcp/src/test/java/org/adcontextprotocol/adcp/error/AdcpErrorTest.java @@ -0,0 +1,178 @@ +package org.adcontextprotocol.adcp.error; + +import org.adcontextprotocol.adcp.auth.AuthChallengeInfo; +import org.adcontextprotocol.adcp.auth.OAuthMetadataInfo; +import org.junit.jupiter.api.Test; + +import java.net.URI; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Tests for the {@link AdcpError} sealed hierarchy. + */ +class AdcpErrorTest { + + @Test + void protocolError_carries_protocol_and_cause() { + var cause = new RuntimeException("transport failed"); + var error = new ProtocolError("mcp", "MCP call failed", cause); + + assertEquals("PROTOCOL_ERROR", error.code()); + assertEquals("mcp", error.protocol()); + assertSame(cause, error.getCause()); + } + + @Test + void authenticationRequiredError_carries_challenge() { + var challenge = new AuthChallengeInfo("bearer", "example", null, null, null); + var error = new AuthenticationRequiredError( + URI.create("https://agent.example.com"), challenge, null); + + assertEquals("AUTHENTICATION_REQUIRED", error.code()); + assertEquals("bearer", error.suggestedScheme()); + assertFalse(error.hasOAuth()); + assertNotNull(error.challenge()); + } + + @Test + void authenticationRequiredError_with_oauth() { + var oauth = new OAuthMetadataInfo( + "https://auth.example.com/authorize", + "https://auth.example.com/token", + null, null); + var error = new AuthenticationRequiredError( + URI.create("https://agent.example.com"), null, oauth); + + assertTrue(error.hasOAuth()); + assertNull(error.suggestedScheme()); + assertEquals("https://auth.example.com/token", + error.oauthMetadata().tokenEndpoint()); + } + + @Test + void taskTimeoutError_carries_details() { + var error = new TaskTimeoutError("task-123", 120000); + + assertEquals("TASK_TIMEOUT", error.code()); + assertEquals("task-123", error.taskId()); + assertEquals(120000, error.timeoutMs()); + assertTrue(error.getMessage().contains("120000")); + } + + @Test + void taskAbortedError() { + var error = new TaskAbortedError("task-456", "client cancelled"); + + assertEquals("TASK_ABORTED", error.code()); + assertEquals("task-456", error.taskId()); + } + + @Test + void deferredTaskError_carries_token() { + var error = new DeferredTaskError("defer-token-789"); + + assertEquals("TASK_DEFERRED", error.code()); + assertEquals("defer-token-789", error.token()); + } + + @Test + void validationError() { + var error = new ValidationError("Invalid field value", "brief"); + + assertEquals("VALIDATION_ERROR", error.code()); + assertEquals(java.util.List.of("brief"), error.path()); + } + + @Test + void configurationError() { + var error = new ConfigurationError("Missing agent URI", "agentUri"); + + assertEquals("CONFIGURATION_ERROR", error.code()); + assertEquals("agentUri", error.configField()); + } + + @Test + void versionUnsupportedError() { + var error = new VersionUnsupportedError( + "get_products", "version", "2.5", + URI.create("https://agent.example.com")); + + assertEquals("VERSION_UNSUPPORTED", error.code()); + assertEquals("get_products", error.taskType()); + assertEquals("version", error.reason()); + } + + @Test + void agentNotFoundError() { + var error = new AgentNotFoundError("sales", List.of("marketing", "ops")); + + assertEquals("AGENT_NOT_FOUND", error.code()); + assertEquals("sales", error.agentId()); + assertEquals(List.of("marketing", "ops"), error.availableAgents()); + } + + @Test + void unsupportedTaskError() { + var error = new UnsupportedTaskError("get_products"); + + assertEquals("UNSUPPORTED_TASK", error.code()); + assertEquals("get_products", error.taskName()); + } + + @Test + void featureUnsupportedError() { + var error = new FeatureUnsupportedError( + List.of("webhooks"), List.of("products")); + + assertEquals("FEATURE_UNSUPPORTED", error.code()); + assertEquals(List.of("webhooks"), error.unsupportedFeatures()); + } + + @Test + void responseTooLargeError() { + var error = new ResponseTooLargeError( + 4096, 50000, URI.create("https://agent.example.com")); + + assertEquals("RESPONSE_TOO_LARGE", error.code()); + assertEquals(4096, error.limit()); + assertEquals(50000, error.bytesRead()); + } + + @Test + void idempotencyErrors() { + var conflict = new IdempotencyConflictError("key already in use"); + assertEquals("IDEMPOTENCY_CONFLICT", conflict.code()); + + var expired = new IdempotencyExpiredError("key TTL exceeded"); + assertEquals("IDEMPOTENCY_EXPIRED", expired.code()); + } + + @Test + void all_errors_extend_adcpError() { + // Verify the sealed hierarchy — all subclasses are AdcpError + assertInstanceOf(AdcpError.class, new ProtocolError("mcp", "test", null)); + assertInstanceOf(AdcpError.class, new AuthenticationRequiredError( + URI.create("https://a.com"), null, null)); + assertInstanceOf(AdcpError.class, new TaskTimeoutError(null, 1000)); + assertInstanceOf(AdcpError.class, new TaskAbortedError("t", null)); + assertInstanceOf(AdcpError.class, new DeferredTaskError("t")); + assertInstanceOf(AdcpError.class, new ValidationError("m", null)); + assertInstanceOf(AdcpError.class, new ConfigurationError("m", null)); + assertInstanceOf(AdcpError.class, new VersionUnsupportedError(null, "r", null, null)); + assertInstanceOf(AdcpError.class, new AgentNotFoundError("a", List.of())); + assertInstanceOf(AdcpError.class, new UnsupportedTaskError("t")); + assertInstanceOf(AdcpError.class, new FeatureUnsupportedError(List.of(), List.of())); + assertInstanceOf(AdcpError.class, new ResponseTooLargeError(1, 2, null)); + assertInstanceOf(AdcpError.class, new IdempotencyConflictError("m")); + assertInstanceOf(AdcpError.class, new IdempotencyExpiredError("m")); + } + + @Test + void all_errors_are_unchecked() { + // AdcpError extends RuntimeException — callers don't need try/catch + assertInstanceOf(RuntimeException.class, + new ProtocolError("mcp", "test", null)); + } +} diff --git a/adcp/src/test/java/org/adcontextprotocol/adcp/http/AdcpHttpClientTest.java b/adcp/src/test/java/org/adcontextprotocol/adcp/http/AdcpHttpClientTest.java new file mode 100644 index 0000000..af206f6 --- /dev/null +++ b/adcp/src/test/java/org/adcontextprotocol/adcp/http/AdcpHttpClientTest.java @@ -0,0 +1,156 @@ +package org.adcontextprotocol.adcp.http; + +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.URI; +import java.net.UnknownHostException; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Tests for {@link AdcpHttpClient} focusing on SSRF protections, + * body capping, and redirect behavior. + * + *

These tests validate the API contract without making real network + * calls where possible. Live HTTP tests are deferred to integration tests. + */ +class AdcpHttpClientTest { + + @Test + void builder_defaults_to_strict_ssrf_policy() { + AdcpHttpClient client = AdcpHttpClient.builder().build(); + assertSame(SsrfPolicy.strict(), client.ssrfPolicy()); + } + + @Test + void builder_sets_max_response_bytes() { + AdcpHttpClient client = AdcpHttpClient.builder() + .maxResponseBytes(1024) + .build(); + assertEquals(1024, client.maxResponseBytes()); + } + + @Test + void builder_rejects_non_positive_max_response_bytes() { + assertThrows(IllegalArgumentException.class, + () -> AdcpHttpClient.builder().maxResponseBytes(0)); + assertThrows(IllegalArgumentException.class, + () -> AdcpHttpClient.builder().maxResponseBytes(-1)); + } + + @Test + void builder_accepts_permissive_policy() { + AdcpHttpClient client = AdcpHttpClient.builder() + .ssrfPolicy(SsrfPolicy.permissive()) + .build(); + assertSame(SsrfPolicy.permissive(), client.ssrfPolicy()); + } + + @Test + void send_rejects_null_uri() { + AdcpHttpClient client = AdcpHttpClient.builder().build(); + assertThrows(NullPointerException.class, + () -> client.send("GET", null, Map.of(), null)); + } + + @Test + void send_blocks_loopback_with_strict_policy() { + AdcpHttpClient client = AdcpHttpClient.builder().build(); + assertThrows(SsrfBlockedException.class, + () -> client.get(URI.create("http://127.0.0.1/test"), Map.of())); + } + + @Test + void send_blocks_metadata_endpoint_with_strict_policy() { + AdcpHttpClient client = AdcpHttpClient.builder().build(); + assertThrows(SsrfBlockedException.class, + () -> client.get( + URI.create("http://169.254.169.254/latest/meta-data/"), + Map.of())); + } + + @Test + void send_blocks_rfc1918_with_strict_policy() { + AdcpHttpClient client = AdcpHttpClient.builder().build(); + assertThrows(SsrfBlockedException.class, + () -> client.get(URI.create("http://10.0.0.1/admin"), Map.of())); + assertThrows(SsrfBlockedException.class, + () -> client.get(URI.create("http://192.168.1.1/"), Map.of())); + } + + @Test + void default_max_response_bytes_is_4kb() { + assertEquals(4096, AdcpHttpClient.DEFAULT_MAX_RESPONSE_BYTES); + } + + @Test + void client_is_autocloseable() { + // Verify AdcpHttpClient implements AutoCloseable + try (AdcpHttpClient client = AdcpHttpClient.builder().build()) { + assertNotNull(client); + } + } + + @Test + void requireHttps_rejects_plain_http_for_remote_hosts() { + AdcpHttpClient client = AdcpHttpClient.builder() + .ssrfPolicy(SsrfPolicy.permissive()) + .requireHttps(true) + .build(); + IOException ex = assertThrows(IOException.class, + () -> client.get(URI.create("http://agent.example.com/mcp"), Map.of())); + assertTrue(ex.getMessage().contains("requireHttps"), + "Error should mention requireHttps: " + ex.getMessage()); + } + + @Test + void requireHttps_allows_localhost_http() { + // Localhost is exempt from requireHttps for local development. + // We verify the requireHttps check passes; downstream errors + // (connection refused, restricted headers, etc.) are expected. + AdcpHttpClient client = AdcpHttpClient.builder() + .ssrfPolicy(SsrfPolicy.permissive()) + .requireHttps(true) + .build(); + try { + client.get(URI.create("http://localhost:4500/mcp"), Map.of()); + // If it succeeds (unlikely in test env), that's fine too + } catch (Exception e) { + // Walk the exception chain — requireHttps rejection must NOT appear + for (Throwable t = e; t != null; t = t.getCause()) { + assertFalse( + t.getMessage() != null && t.getMessage().contains("requireHttps"), + "Localhost should be exempt from requireHttps: " + t.getMessage()); + } + } + } + + @Test + void requireHttps_defaults_to_false() { + // Default behavior should not block http:// via requireHttps + AdcpHttpClient client = AdcpHttpClient.builder() + .ssrfPolicy(SsrfPolicy.permissive()) + .build(); + try { + client.get(URI.create("http://agent.example.com/mcp"), Map.of()); + } catch (Exception e) { + for (Throwable t = e; t != null; t = t.getCause()) { + assertFalse( + t.getMessage() != null && t.getMessage().contains("requireHttps"), + "requireHttps should default to false: " + t.getMessage()); + } + } + } + + @Test + void send_rejects_octal_ip_literal() { + AdcpHttpClient client = AdcpHttpClient.builder() + .ssrfPolicy(SsrfPolicy.permissive()) + .build(); + assertThrows(SsrfBlockedException.class, + () -> client.get(URI.create("http://0177.0.0.1/test"), Map.of())); + } +} diff --git a/adcp/src/test/java/org/adcontextprotocol/adcp/http/DnsPinResolverTest.java b/adcp/src/test/java/org/adcontextprotocol/adcp/http/DnsPinResolverTest.java new file mode 100644 index 0000000..7fdb49c --- /dev/null +++ b/adcp/src/test/java/org/adcontextprotocol/adcp/http/DnsPinResolverTest.java @@ -0,0 +1,73 @@ +package org.adcontextprotocol.adcp.http; + +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.UnknownHostException; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Tests for {@link DnsPinResolver}. + */ +class DnsPinResolverTest { + + @Test + void resolveAndPin_blocks_loopback() { + assertThrows(SsrfBlockedException.class, + () -> DnsPinResolver.resolveAndPin("127.0.0.1", SsrfPolicy.strict())); + } + + @Test + void resolveAndPin_blocks_rfc1918() { + assertThrows(SsrfBlockedException.class, + () -> DnsPinResolver.resolveAndPin("10.0.0.1", SsrfPolicy.strict())); + } + + @Test + void resolveAndPin_blocks_link_local() { + assertThrows(SsrfBlockedException.class, + () -> DnsPinResolver.resolveAndPin("169.254.169.254", SsrfPolicy.strict())); + } + + @Test + void resolveAndPin_allows_with_permissive_policy() throws IOException { + InetAddress addr = DnsPinResolver.resolveAndPin( + "127.0.0.1", SsrfPolicy.permissive()); + assertNotNull(addr); + } + + @Test + void validateAddress_blocks_denied() { + InetAddress addr; + try { + addr = InetAddress.getByName("10.0.0.1"); + } catch (UnknownHostException e) { + fail("Could not resolve 10.0.0.1", e); + return; + } + assertThrows(SsrfBlockedException.class, + () -> DnsPinResolver.validateAddress(addr, SsrfPolicy.strict())); + } + + @Test + void validateAddress_allows_public() throws UnknownHostException { + InetAddress addr = InetAddress.getByName("8.8.8.8"); + assertDoesNotThrow( + () -> DnsPinResolver.validateAddress(addr, SsrfPolicy.strict())); + } + + @Test + void ssrfBlockedException_carries_reason() { + try { + DnsPinResolver.resolveAndPin("127.0.0.1", SsrfPolicy.strict()); + fail("Expected SsrfBlockedException"); + } catch (SsrfBlockedException e) { + assertEquals("127.0.0.1", e.host()); + assertFalse(e.reason().isBlank()); + } catch (IOException e) { + fail("Unexpected IOException", e); + } + } +} diff --git a/adcp/src/test/java/org/adcontextprotocol/adcp/http/StrictSsrfPolicyTest.java b/adcp/src/test/java/org/adcontextprotocol/adcp/http/StrictSsrfPolicyTest.java index d0a1619..7b637a4 100644 --- a/adcp/src/test/java/org/adcontextprotocol/adcp/http/StrictSsrfPolicyTest.java +++ b/adcp/src/test/java/org/adcontextprotocol/adcp/http/StrictSsrfPolicyTest.java @@ -44,7 +44,11 @@ class StrictSsrfPolicyTest { "ff02::1", // IPv6 multicast "::ffff:127.0.0.1", // IPv4-mapped IPv6 loopback "::ffff:10.0.0.1", // IPv4-mapped IPv6 RFC 1918 - "::ffff:169.254.169.254" // IPv4-mapped IPv6 cloud metadata + "::ffff:169.254.169.254", // IPv4-mapped IPv6 cloud metadata + "::127.0.0.1", // IPv4-compatible loopback + "::10.0.0.1", // IPv4-compatible RFC 1918 + "::169.254.169.254", // IPv4-compatible cloud metadata + "::192.168.1.1" // IPv4-compatible RFC 1918 }) void denies_block_table(String literal) throws UnknownHostException { InetAddress addr = InetAddress.getByName(literal); diff --git a/adcp/src/test/java/org/adcontextprotocol/adcp/schema/AdcpObjectMapperFactoryTest.java b/adcp/src/test/java/org/adcontextprotocol/adcp/schema/AdcpObjectMapperFactoryTest.java index 3dad82d..74c519f 100644 --- a/adcp/src/test/java/org/adcontextprotocol/adcp/schema/AdcpObjectMapperFactoryTest.java +++ b/adcp/src/test/java/org/adcontextprotocol/adcp/schema/AdcpObjectMapperFactoryTest.java @@ -50,18 +50,18 @@ void factory_tolerates_unknown_fields() throws Exception { void factory_widens_stream_read_constraints() { ObjectMapper mapper = AdcpObjectMapperFactory.create(); StreamReadConstraints constraints = mapper.getFactory().streamReadConstraints(); - assertTrue(constraints.getMaxStringLength() >= 100_000_000, - "MaxStringLength should be at least 100MB for creative payloads"); - assertTrue(constraints.getMaxNestingDepth() >= 2000, - "Read MaxNestingDepth should be at least 2000 for deep catalog responses"); + assertTrue(constraints.getMaxStringLength() >= 10_000_000, + "MaxStringLength should be at least 10MB for creative payloads"); + assertTrue(constraints.getMaxNestingDepth() >= 200, + "Read MaxNestingDepth should be at least 200 for deep catalog responses"); } @Test void factory_widens_stream_write_constraints() { ObjectMapper mapper = AdcpObjectMapperFactory.create(); StreamWriteConstraints constraints = mapper.getFactory().streamWriteConstraints(); - assertTrue(constraints.getMaxNestingDepth() >= 2000, - "MaxNestingDepth should be at least 2000 for deep catalog responses"); + assertTrue(constraints.getMaxNestingDepth() >= 200, + "MaxNestingDepth should be at least 200 for deep catalog responses"); } @Test diff --git a/adcp/src/test/java/org/adcontextprotocol/adcp/transport/VersionEnvelopeTest.java b/adcp/src/test/java/org/adcontextprotocol/adcp/transport/VersionEnvelopeTest.java new file mode 100644 index 0000000..7f6bd87 --- /dev/null +++ b/adcp/src/test/java/org/adcontextprotocol/adcp/transport/VersionEnvelopeTest.java @@ -0,0 +1,71 @@ +package org.adcontextprotocol.adcp.transport; + +import org.adcontextprotocol.adcp.AdcpVersion; +import org.junit.jupiter.api.Test; + +import java.util.LinkedHashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Tests for {@link VersionEnvelope}. + */ +class VersionEnvelopeTest { + + @Test + void build_default_version() { + Map envelope = VersionEnvelope.build(null); + + assertEquals(3, envelope.get("adcp_major_version")); + assertFalse(envelope.containsKey("adcp_version")); + } + + @Test + void build_v3_explicit() { + Map envelope = VersionEnvelope.build(AdcpVersion.V3); + + assertEquals(3, envelope.get("adcp_major_version")); + assertFalse(envelope.containsKey("adcp_version")); + } + + @Test + void build_v3_1() { + Map envelope = VersionEnvelope.build(AdcpVersion.V3_1); + + assertEquals(3, envelope.get("adcp_major_version")); + assertEquals("3.1", envelope.get("adcp_version")); + } + + @Test + void mergeInto_sdk_version_wins_over_caller() { + Map callerArgs = new LinkedHashMap<>(); + callerArgs.put("adcp_major_version", 99); + callerArgs.put("my_param", "value"); + + Map merged = VersionEnvelope.mergeInto(callerArgs, AdcpVersion.V3); + + // SDK value wins — caller override is discarded with a warning + assertEquals(3, merged.get("adcp_major_version")); + // Caller's own param preserved + assertEquals("value", merged.get("my_param")); + } + + @Test + void mergeInto_injects_version_when_caller_doesnt_set() { + Map callerArgs = Map.of("param", "val"); + + Map merged = VersionEnvelope.mergeInto(callerArgs, AdcpVersion.V3); + + assertEquals(3, merged.get("adcp_major_version")); + assertEquals("val", merged.get("param")); + } + + @Test + void mergeInto_null_callerArgs_returns_envelope_only() { + Map merged = VersionEnvelope.mergeInto(null, AdcpVersion.V3); + + assertEquals(3, merged.get("adcp_major_version")); + assertEquals(1, merged.size()); + } +} diff --git a/adcp/src/test/java/org/adcontextprotocol/adcp/transport/mcp/McpConnectionManagerTest.java b/adcp/src/test/java/org/adcontextprotocol/adcp/transport/mcp/McpConnectionManagerTest.java new file mode 100644 index 0000000..0e57d30 --- /dev/null +++ b/adcp/src/test/java/org/adcontextprotocol/adcp/transport/mcp/McpConnectionManagerTest.java @@ -0,0 +1,57 @@ +package org.adcontextprotocol.adcp.transport.mcp; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Tests for {@link McpConnectionManager}. + * + *

These tests verify the cache management, eviction, and lifecycle + * behavior without making real MCP connections (which require a running + * MCP server). + */ +class McpConnectionManagerTest { + + private final McpConnectionManager manager = new McpConnectionManager(); + + @AfterEach + void cleanup() { + manager.close(); + } + + @Test + void close_clears_cache() { + // Just verify close doesn't throw on empty cache + assertDoesNotThrow(manager::close); + } + + @Test + void implements_autocloseable() { + // Verify the manager can be used in try-with-resources + try (McpConnectionManager mgr = new McpConnectionManager()) { + assertNotNull(mgr); + } + } + + @Test + void evict_nonexistent_is_noop() { + var uri = java.net.URI.create("https://agent.example.com"); + assertDoesNotThrow(() -> manager.evict(uri, "abc")); + } + + @Test + void getOrConnect_after_close_throws() { + manager.close(); + var uri = java.net.URI.create("https://agent.example.com"); + assertThrows(IllegalStateException.class, + () -> manager.getOrConnect(uri, java.util.Map.of(), "hash")); + } + + @Test + void double_close_is_safe() { + manager.close(); + assertDoesNotThrow(manager::close); + } +}