diff --git a/capiscio_sdk/connect.py b/capiscio_sdk/connect.py index f04b78e..8a219e6 100644 --- a/capiscio_sdk/connect.py +++ b/capiscio_sdk/connect.py @@ -92,6 +92,18 @@ def _print_agent_key_capture_hint(agent_id: str, private_jwk: dict) -> None: # Standalone Helper Functions (for testing and direct use) # ============================================================================= +def _did_method(did: str) -> str: + """Extract the method from a DID string (e.g. 'key' from 'did:key:z6Mk...').""" + parts = did.split(":", 3) + return parts[1] if len(parts) >= 3 else "" + + +def _is_did_transition(did_a: str, did_b: str) -> bool: + """Return True if the two DIDs represent a safe did:key ↔ did:web transition.""" + methods = {_did_method(did_a), _did_method(did_b)} + return methods == {"key", "web"} + + def _read_did_from_keys(identity_path: Path) -> Optional[str]: """ Read DID from an identity directory. @@ -552,7 +564,14 @@ def _find_agent_from_local_keys(self) -> Optional[Dict[str, Any]]: server_did = agent_data.get("did") # Verify DID matches if server has one if not server_did or server_did == local_did: - return agent_data + if not self.name or agent_data.get("name") == self.name: + return agent_data + # did:key ↔ did:web transitions are expected when + # server upgrades DID method; agent_id already confirms identity. + if _is_did_transition(local_did, server_did): + logger.debug(f"DID method transition for {agent_id}: {local_did} → {server_did}") + if not self.name or agent_data.get("name") == self.name: + return agent_data logger.debug(f"DID mismatch: local={local_did}, server={server_did}") except Exception: pass @@ -594,8 +613,17 @@ def _find_agent_from_local_keys(self) -> Optional[Dict[str, Any]]: # If server has DID, verify it matches local keys if server_did and local_did and server_did != local_did: - logger.warning(f"DID mismatch for {agent_id}: local={local_did}, server={server_did}") - continue # Don't use mismatched agent + # did:key ↔ did:web transitions are expected when + # server upgrades DID method; agent_id already confirms identity. + if _is_did_transition(local_did, server_did): + logger.debug(f"DID method transition for {agent_id}: {local_did} → {server_did}") + else: + logger.warning(f"DID mismatch for {agent_id}: local={local_did}, server={server_did}") + continue # Don't use mismatched agent + + # If caller specified a name, only match agents with that name + if self.name and agent_data.get("name") != self.name: + continue return agent_data except Exception: @@ -679,8 +707,21 @@ def _init_identity(self) -> str: # Derive DID from public key's kid field (RFC-002 §6.1: did:key is self-describing) try: public_jwk = json.loads(public_key_path.read_text()) - did = public_jwk.get("kid") - if did and did.startswith("did:"): + kid = public_jwk.get("kid") + if kid and kid.startswith("did:"): + # Prefer did:web derived from server_url + agent_id (RFC-002 §6.1). + # The JWK kid is did:key (key confirmation), but the agent's + # primary DID should be did:web::agents:. + did = kid # fallback + if self.server_url and self.agent_id: + try: + from urllib.parse import urlparse + domain = urlparse(self.server_url).netloc + if domain: + did = f"did:web:{domain}:agents:{self.agent_id}" + except Exception: + pass + logger.info(f"Recovered identity from existing keys: {did}") # Ensure DID is registered with server (may have failed previously) diff --git a/tests/unit/test_connect.py b/tests/unit/test_connect.py index 6b06c60..ed30645 100644 --- a/tests/unit/test_connect.py +++ b/tests/unit/test_connect.py @@ -18,6 +18,8 @@ DEFAULT_KEYS_DIR, ENV_AGENT_PRIVATE_KEY, PROD_REGISTRY, + _did_method, + _is_did_transition, _print_agent_key_capture_hint, _public_jwk_from_private, ) @@ -845,7 +847,8 @@ def test_init_identity_uses_existing(self, tmp_path): result = connector._init_identity() - assert result == "did:key:z6MkExisting" + # With server_url + agent_id set, the code prefers did:web over the did:key in the JWK kid + assert result == "did:web:test.server.com:agents:agent-123" def test_init_identity_calls_rpc(self, tmp_path): """Test _init_identity calls capiscio-core RPC.""" @@ -1346,6 +1349,44 @@ def test_missing_kid_regenerates(self, tmp_path): mock_rpc.simpleguard.init.assert_called_once() +class TestDidMethod: + """Tests for the _did_method helper function.""" + + def test_extracts_key_method(self): + assert _did_method("did:key:z6MkTest") == "key" + + def test_extracts_web_method(self): + assert _did_method("did:web:example.com:agents:123") == "web" + + def test_empty_string(self): + assert _did_method("") == "" + + def test_too_few_parts(self): + assert _did_method("did:key") == "" + + def test_other_method(self): + assert _did_method("did:pkh:eip155:1:0x1234") == "pkh" + + +class TestIsDidTransition: + """Tests for the _is_did_transition helper function.""" + + def test_key_to_web(self): + assert _is_did_transition("did:key:z6MkTest", "did:web:example.com:agents:123") is True + + def test_web_to_key(self): + assert _is_did_transition("did:web:example.com:agents:123", "did:key:z6MkTest") is True + + def test_same_method_key(self): + assert _is_did_transition("did:key:z6MkA", "did:key:z6MkB") is False + + def test_same_method_web(self): + assert _is_did_transition("did:web:a.com", "did:web:b.com") is False + + def test_unrelated_methods(self): + assert _is_did_transition("did:key:z6MkTest", "did:pkh:eip155:1:0x1234") is False + + class TestReadDidFromKeys: """Tests for the _read_did_from_keys standalone helper function.""" @@ -1465,6 +1506,231 @@ def test_finds_agent_with_valid_uuid_keys(self, tmp_path): assert result is not None assert result["id"] == agent_uuid + def test_did_transition_accepted_in_scan(self, tmp_path): + """Test that did:key → did:web transitions are accepted during scan.""" + connector = _Connector( + api_key="sk_test", + name=None, + agent_id=None, + server_url="https://test.server.com", + keys_dir=tmp_path, + auto_badge=False, + dev_mode=False, + ) + + agent_uuid = "12345678-1234-1234-1234-123456789012" + uuid_dir = tmp_path / agent_uuid + uuid_dir.mkdir() + (uuid_dir / "private.jwk").write_text('{"kty":"OKP"}') + (uuid_dir / "public.jwk").write_text( + '{"kty":"OKP","kid":"did:key:z6MkLocal"}' + ) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": { + "id": agent_uuid, + "name": "My Agent", + "did": "did:web:example.com:agents:12345678-1234-1234-1234-123456789012", + } + } + connector._client = MagicMock() + connector._client.get = MagicMock(return_value=mock_response) + + result = connector._find_agent_from_local_keys() + + assert result is not None + assert result["id"] == agent_uuid + + def test_did_mismatch_rejected_in_scan(self, tmp_path): + """Test that non-transition DID mismatches are rejected during scan.""" + connector = _Connector( + api_key="sk_test", + name=None, + agent_id=None, + server_url="https://test.server.com", + keys_dir=tmp_path, + auto_badge=False, + dev_mode=False, + ) + + agent_uuid = "12345678-1234-1234-1234-123456789012" + uuid_dir = tmp_path / agent_uuid + uuid_dir.mkdir() + (uuid_dir / "private.jwk").write_text('{"kty":"OKP"}') + (uuid_dir / "public.jwk").write_text( + '{"kty":"OKP","kid":"did:key:z6MkLocal"}' + ) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": { + "id": agent_uuid, + "name": "My Agent", + "did": "did:key:z6MkDifferent", + } + } + connector._client = MagicMock() + connector._client.get = MagicMock(return_value=mock_response) + + result = connector._find_agent_from_local_keys() + + assert result is None + + def test_name_filter_rejects_mismatch_in_scan(self, tmp_path): + """Test that name filter rejects agents with wrong name during scan.""" + connector = _Connector( + api_key="sk_test", + name="Expected Name", + agent_id=None, + server_url="https://test.server.com", + keys_dir=tmp_path, + auto_badge=False, + dev_mode=False, + ) + + agent_uuid = "12345678-1234-1234-1234-123456789012" + uuid_dir = tmp_path / agent_uuid + uuid_dir.mkdir() + (uuid_dir / "private.jwk").write_text('{"kty":"OKP"}') + (uuid_dir / "public.jwk").write_text( + '{"kty":"OKP","kid":"did:key:z6MkLocal"}' + ) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": {"id": agent_uuid, "name": "Wrong Name"} + } + connector._client = MagicMock() + connector._client.get = MagicMock(return_value=mock_response) + + result = connector._find_agent_from_local_keys() + + assert result is None + + def test_flat_keys_did_transition_accepted(self, tmp_path): + """Test did:key → did:web transition in flat keys_dir layout.""" + keys_dir = tmp_path / "12345678-1234-1234-1234-123456789012" + keys_dir.mkdir() + (keys_dir / "private.jwk").write_text('{"kty":"OKP"}') + (keys_dir / "public.jwk").write_text( + '{"kty":"OKP","kid":"did:key:z6MkLocal"}' + ) + + connector = _Connector( + api_key="sk_test", + name=None, + agent_id=None, + server_url="https://test.server.com", + keys_dir=str(keys_dir), + auto_badge=False, + dev_mode=False, + ) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": { + "id": "12345678-1234-1234-1234-123456789012", + "name": "My Agent", + "did": "did:web:example.com:agents:12345678-1234-1234-1234-123456789012", + } + } + connector._client = MagicMock() + connector._client.get = MagicMock(return_value=mock_response) + + result = connector._find_agent_from_local_keys() + + assert result is not None + assert result["id"] == "12345678-1234-1234-1234-123456789012" + + def test_flat_keys_name_filter_rejects_mismatch(self, tmp_path): + """Test name filter rejects mismatched agents in flat keys_dir layout.""" + keys_dir = tmp_path / "12345678-1234-1234-1234-123456789012" + keys_dir.mkdir() + (keys_dir / "private.jwk").write_text('{"kty":"OKP"}') + (keys_dir / "public.jwk").write_text( + '{"kty":"OKP","kid":"did:key:z6MkLocal"}' + ) + + connector = _Connector( + api_key="sk_test", + name="Expected Name", + agent_id=None, + server_url="https://test.server.com", + keys_dir=str(keys_dir), + auto_badge=False, + dev_mode=False, + ) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": { + "id": "12345678-1234-1234-1234-123456789012", + "name": "Wrong Name", + "did": "did:key:z6MkLocal", + } + } + connector._client = MagicMock() + connector._client.get = MagicMock(return_value=mock_response) + + result = connector._find_agent_from_local_keys() + + # Name mismatch — should not return this agent + assert result is None + + +class TestInitIdentityDidWebDerivation: + """Tests for did:web derivation in _init_identity.""" + + def test_did_web_fallback_without_server_url(self, tmp_path): + """Test _init_identity falls back to did:key when server_url is empty.""" + connector = _Connector( + api_key="sk_test", + name="Test", + agent_id="agent-123", + server_url="", + keys_dir=tmp_path, + auto_badge=False, + dev_mode=False, + ) + + (tmp_path / "private.jwk").write_text('{"kty":"OKP","crv":"Ed25519"}') + (tmp_path / "public.jwk").write_text( + '{"kty":"OKP","crv":"Ed25519","kid":"did:key:z6MkExisting"}' + ) + connector._ensure_did_registered = MagicMock(return_value=None) + + result = connector._init_identity() + + assert result == "did:key:z6MkExisting" + + def test_did_web_fallback_without_agent_id(self, tmp_path): + """Test _init_identity falls back to did:key when agent_id is missing.""" + connector = _Connector( + api_key="sk_test", + name="Test", + agent_id=None, + server_url="https://test.server.com", + keys_dir=tmp_path, + auto_badge=False, + dev_mode=False, + ) + + (tmp_path / "private.jwk").write_text('{"kty":"OKP","crv":"Ed25519"}') + (tmp_path / "public.jwk").write_text( + '{"kty":"OKP","crv":"Ed25519","kid":"did:key:z6MkExisting"}' + ) + connector._ensure_did_registered = MagicMock(return_value=None) + + result = connector._init_identity() + + assert result == "did:key:z6MkExisting" + class TestEnsureDidRegisteredMethod: """Tests for _ensure_did_registered method."""