Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 46 additions & 5 deletions capiscio_sdk/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,18 @@ def _print_agent_key_capture_hint(agent_id: str, private_jwk: dict) -> None:
# Standalone Helper Functions (for testing and direct use)
# =============================================================================

def _did_method(did: str) -> str:
"""Extract the method from a DID string (e.g. 'key' from 'did:key:z6Mk...')."""
parts = did.split(":", 3)
return parts[1] if len(parts) >= 3 else ""


def _is_did_transition(did_a: str, did_b: str) -> bool:
"""Return True if the two DIDs represent a safe did:key ↔ did:web transition."""
methods = {_did_method(did_a), _did_method(did_b)}
return methods == {"key", "web"}


def _read_did_from_keys(identity_path: Path) -> Optional[str]:
"""
Read DID from an identity directory.
Expand Down Expand Up @@ -552,7 +564,14 @@ def _find_agent_from_local_keys(self) -> Optional[Dict[str, Any]]:
server_did = agent_data.get("did")
# Verify DID matches if server has one
if not server_did or server_did == local_did:
return agent_data
if not self.name or agent_data.get("name") == self.name:
return agent_data
# did:key ↔ did:web transitions are expected when
# server upgrades DID method; agent_id already confirms identity.
if _is_did_transition(local_did, server_did):
Comment thread
beonde marked this conversation as resolved.
logger.debug(f"DID method transition for {agent_id}: {local_did} → {server_did}")
if not self.name or agent_data.get("name") == self.name:
return agent_data
Comment thread
beonde marked this conversation as resolved.
logger.debug(f"DID mismatch: local={local_did}, server={server_did}")
except Exception:
pass
Expand Down Expand Up @@ -594,8 +613,17 @@ def _find_agent_from_local_keys(self) -> Optional[Dict[str, Any]]:

# If server has DID, verify it matches local keys
if server_did and local_did and server_did != local_did:
logger.warning(f"DID mismatch for {agent_id}: local={local_did}, server={server_did}")
continue # Don't use mismatched agent
# did:key ↔ did:web transitions are expected when
# server upgrades DID method; agent_id already confirms identity.
if _is_did_transition(local_did, server_did):
logger.debug(f"DID method transition for {agent_id}: {local_did} → {server_did}")
else:
logger.warning(f"DID mismatch for {agent_id}: local={local_did}, server={server_did}")
continue # Don't use mismatched agent

# If caller specified a name, only match agents with that name
if self.name and agent_data.get("name") != self.name:
continue

return agent_data
except Exception:
Expand Down Expand Up @@ -679,8 +707,21 @@ def _init_identity(self) -> str:
# Derive DID from public key's kid field (RFC-002 §6.1: did:key is self-describing)
try:
public_jwk = json.loads(public_key_path.read_text())
did = public_jwk.get("kid")
if did and did.startswith("did:"):
kid = public_jwk.get("kid")
if kid and kid.startswith("did:"):
# Prefer did:web derived from server_url + agent_id (RFC-002 §6.1).
# The JWK kid is did:key (key confirmation), but the agent's
# primary DID should be did:web:<domain>:agents:<agent_id>.
did = kid # fallback
if self.server_url and self.agent_id:
try:
from urllib.parse import urlparse
domain = urlparse(self.server_url).netloc
Comment thread
beonde marked this conversation as resolved.
if domain:
did = f"did:web:{domain}:agents:{self.agent_id}"
except Exception:
pass

logger.info(f"Recovered identity from existing keys: {did}")
Comment thread
beonde marked this conversation as resolved.

# Ensure DID is registered with server (may have failed previously)
Expand Down
268 changes: 267 additions & 1 deletion tests/unit/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
DEFAULT_KEYS_DIR,
ENV_AGENT_PRIVATE_KEY,
PROD_REGISTRY,
_did_method,
_is_did_transition,
_print_agent_key_capture_hint,
_public_jwk_from_private,
)
Expand Down Expand Up @@ -845,7 +847,8 @@ def test_init_identity_uses_existing(self, tmp_path):

result = connector._init_identity()

assert result == "did:key:z6MkExisting"
# With server_url + agent_id set, the code prefers did:web over the did:key in the JWK kid
assert result == "did:web:test.server.com:agents:agent-123"

def test_init_identity_calls_rpc(self, tmp_path):
"""Test _init_identity calls capiscio-core RPC."""
Expand Down Expand Up @@ -1346,6 +1349,44 @@ def test_missing_kid_regenerates(self, tmp_path):
mock_rpc.simpleguard.init.assert_called_once()


class TestDidMethod:
"""Tests for the _did_method helper function."""

def test_extracts_key_method(self):
assert _did_method("did:key:z6MkTest") == "key"

def test_extracts_web_method(self):
assert _did_method("did:web:example.com:agents:123") == "web"

def test_empty_string(self):
assert _did_method("") == ""

def test_too_few_parts(self):
assert _did_method("did:key") == ""

def test_other_method(self):
assert _did_method("did:pkh:eip155:1:0x1234") == "pkh"


class TestIsDidTransition:
"""Tests for the _is_did_transition helper function."""

def test_key_to_web(self):
assert _is_did_transition("did:key:z6MkTest", "did:web:example.com:agents:123") is True

def test_web_to_key(self):
assert _is_did_transition("did:web:example.com:agents:123", "did:key:z6MkTest") is True

def test_same_method_key(self):
assert _is_did_transition("did:key:z6MkA", "did:key:z6MkB") is False

def test_same_method_web(self):
assert _is_did_transition("did:web:a.com", "did:web:b.com") is False

def test_unrelated_methods(self):
assert _is_did_transition("did:key:z6MkTest", "did:pkh:eip155:1:0x1234") is False


class TestReadDidFromKeys:
"""Tests for the _read_did_from_keys standalone helper function."""

Expand Down Expand Up @@ -1465,6 +1506,231 @@ def test_finds_agent_with_valid_uuid_keys(self, tmp_path):
assert result is not None
assert result["id"] == agent_uuid

def test_did_transition_accepted_in_scan(self, tmp_path):
"""Test that did:key → did:web transitions are accepted during scan."""
connector = _Connector(
api_key="sk_test",
name=None,
agent_id=None,
server_url="https://test.server.com",
keys_dir=tmp_path,
auto_badge=False,
dev_mode=False,
)

agent_uuid = "12345678-1234-1234-1234-123456789012"
uuid_dir = tmp_path / agent_uuid
uuid_dir.mkdir()
(uuid_dir / "private.jwk").write_text('{"kty":"OKP"}')
(uuid_dir / "public.jwk").write_text(
'{"kty":"OKP","kid":"did:key:z6MkLocal"}'
)

mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"data": {
"id": agent_uuid,
"name": "My Agent",
"did": "did:web:example.com:agents:12345678-1234-1234-1234-123456789012",
}
}
connector._client = MagicMock()
connector._client.get = MagicMock(return_value=mock_response)

result = connector._find_agent_from_local_keys()

assert result is not None
assert result["id"] == agent_uuid

def test_did_mismatch_rejected_in_scan(self, tmp_path):
"""Test that non-transition DID mismatches are rejected during scan."""
connector = _Connector(
api_key="sk_test",
name=None,
agent_id=None,
server_url="https://test.server.com",
keys_dir=tmp_path,
auto_badge=False,
dev_mode=False,
)

agent_uuid = "12345678-1234-1234-1234-123456789012"
uuid_dir = tmp_path / agent_uuid
uuid_dir.mkdir()
(uuid_dir / "private.jwk").write_text('{"kty":"OKP"}')
(uuid_dir / "public.jwk").write_text(
'{"kty":"OKP","kid":"did:key:z6MkLocal"}'
)

mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"data": {
"id": agent_uuid,
"name": "My Agent",
"did": "did:key:z6MkDifferent",
}
}
connector._client = MagicMock()
connector._client.get = MagicMock(return_value=mock_response)

result = connector._find_agent_from_local_keys()

assert result is None

def test_name_filter_rejects_mismatch_in_scan(self, tmp_path):
"""Test that name filter rejects agents with wrong name during scan."""
connector = _Connector(
api_key="sk_test",
name="Expected Name",
agent_id=None,
server_url="https://test.server.com",
keys_dir=tmp_path,
auto_badge=False,
dev_mode=False,
)

agent_uuid = "12345678-1234-1234-1234-123456789012"
uuid_dir = tmp_path / agent_uuid
uuid_dir.mkdir()
(uuid_dir / "private.jwk").write_text('{"kty":"OKP"}')
(uuid_dir / "public.jwk").write_text(
'{"kty":"OKP","kid":"did:key:z6MkLocal"}'
)

mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"data": {"id": agent_uuid, "name": "Wrong Name"}
}
connector._client = MagicMock()
connector._client.get = MagicMock(return_value=mock_response)

result = connector._find_agent_from_local_keys()

assert result is None

def test_flat_keys_did_transition_accepted(self, tmp_path):
"""Test did:key → did:web transition in flat keys_dir layout."""
keys_dir = tmp_path / "12345678-1234-1234-1234-123456789012"
keys_dir.mkdir()
(keys_dir / "private.jwk").write_text('{"kty":"OKP"}')
(keys_dir / "public.jwk").write_text(
'{"kty":"OKP","kid":"did:key:z6MkLocal"}'
)

connector = _Connector(
api_key="sk_test",
name=None,
agent_id=None,
server_url="https://test.server.com",
keys_dir=str(keys_dir),
auto_badge=False,
dev_mode=False,
)

mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"data": {
"id": "12345678-1234-1234-1234-123456789012",
"name": "My Agent",
"did": "did:web:example.com:agents:12345678-1234-1234-1234-123456789012",
}
}
connector._client = MagicMock()
connector._client.get = MagicMock(return_value=mock_response)

result = connector._find_agent_from_local_keys()

assert result is not None
assert result["id"] == "12345678-1234-1234-1234-123456789012"

def test_flat_keys_name_filter_rejects_mismatch(self, tmp_path):
"""Test name filter rejects mismatched agents in flat keys_dir layout."""
keys_dir = tmp_path / "12345678-1234-1234-1234-123456789012"
keys_dir.mkdir()
(keys_dir / "private.jwk").write_text('{"kty":"OKP"}')
(keys_dir / "public.jwk").write_text(
'{"kty":"OKP","kid":"did:key:z6MkLocal"}'
)

connector = _Connector(
api_key="sk_test",
name="Expected Name",
agent_id=None,
server_url="https://test.server.com",
keys_dir=str(keys_dir),
auto_badge=False,
dev_mode=False,
)

mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"data": {
"id": "12345678-1234-1234-1234-123456789012",
"name": "Wrong Name",
"did": "did:key:z6MkLocal",
}
}
connector._client = MagicMock()
connector._client.get = MagicMock(return_value=mock_response)

result = connector._find_agent_from_local_keys()

# Name mismatch — should not return this agent
assert result is None


class TestInitIdentityDidWebDerivation:
"""Tests for did:web derivation in _init_identity."""

def test_did_web_fallback_without_server_url(self, tmp_path):
"""Test _init_identity falls back to did:key when server_url is empty."""
connector = _Connector(
api_key="sk_test",
name="Test",
agent_id="agent-123",
server_url="",
keys_dir=tmp_path,
auto_badge=False,
dev_mode=False,
)

(tmp_path / "private.jwk").write_text('{"kty":"OKP","crv":"Ed25519"}')
(tmp_path / "public.jwk").write_text(
'{"kty":"OKP","crv":"Ed25519","kid":"did:key:z6MkExisting"}'
)
connector._ensure_did_registered = MagicMock(return_value=None)

result = connector._init_identity()

assert result == "did:key:z6MkExisting"

def test_did_web_fallback_without_agent_id(self, tmp_path):
"""Test _init_identity falls back to did:key when agent_id is missing."""
connector = _Connector(
api_key="sk_test",
name="Test",
agent_id=None,
server_url="https://test.server.com",
keys_dir=tmp_path,
auto_badge=False,
dev_mode=False,
)

(tmp_path / "private.jwk").write_text('{"kty":"OKP","crv":"Ed25519"}')
(tmp_path / "public.jwk").write_text(
'{"kty":"OKP","crv":"Ed25519","kid":"did:key:z6MkExisting"}'
)
connector._ensure_did_registered = MagicMock(return_value=None)

result = connector._init_identity()

assert result == "did:key:z6MkExisting"


class TestEnsureDidRegisteredMethod:
"""Tests for _ensure_did_registered method."""
Expand Down
Loading