From 5325457d64457e2b704009caa40e79a501d68466 Mon Sep 17 00:00:00 2001 From: Juan Pablo Vega Date: Thu, 18 Jun 2026 17:52:24 +0200 Subject: [PATCH] feat(api): extract provider connection into shared routerless connections domain Move the provider connection out of /tools into a shared connections domain (rename tool_connections -> gateway_connections) so triggers can reuse it. The /tools/connections HTTP contract is unchanged. - New core/connections (ConnectionsService + ConnectionsGatewayInterface adapter port) and dbs/postgres/connections; ToolsService delegates connection mgmt. - Composio auth verbs move behind ComposioConnectionsAdapter; connections never imports tools. - revoke stays local-only (is_valid=False); cross-domain effect via the shared row; usage() reports consumers. - Migration authored once in core_oss (oss000000002), runs in both editions. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../pytest/acceptance}/tools/__init__.py | 0 .../tools/test_tools_connections.py | 154 +++++++ api/entrypoints/routers.py | 29 +- ...tool_connections_to_gateway_connections.py | 49 +++ api/oss/src/apis/fastapi/tools/models.py | 13 +- api/oss/src/apis/fastapi/tools/router.py | 2 +- api/oss/src/core/connections/__init__.py | 0 api/oss/src/core/connections/dtos.py | 130 ++++++ api/oss/src/core/connections/exceptions.py | 65 +++ api/oss/src/core/connections/interfaces.py | 127 ++++++ .../core/connections/providers/__init__.py | 0 .../providers/composio/__init__.py | 20 + .../connections/providers/composio/adapter.py | 302 +++++++++++++ api/oss/src/core/connections/registry.py | 27 ++ api/oss/src/core/connections/service.py | 327 ++++++++++++++ .../src/core/{tools => connections}/utils.py | 2 +- api/oss/src/core/tools/dtos.py | 95 ---- api/oss/src/core/tools/interfaces.py | 165 ++----- .../core/tools/providers/composio/adapter.py | 225 +--------- api/oss/src/core/tools/service.py | 405 ++++++------------ .../src/dbs/postgres/connections/__init__.py | 0 api/oss/src/dbs/postgres/connections/dao.py | 282 ++++++++++++ api/oss/src/dbs/postgres/connections/dbes.py | 69 +++ .../{tools => connections}/mappings.py | 24 +- .../tools/test_tools_connections.py | 71 +++ .../unit/models/test_lifecycle_conventions.py | 2 +- 26 files changed, 1831 insertions(+), 754 deletions(-) rename api/{oss/src/dbs/postgres => ee/tests/pytest/acceptance}/tools/__init__.py (100%) create mode 100644 api/ee/tests/pytest/acceptance/tools/test_tools_connections.py create mode 100644 api/oss/databases/postgres/migrations/core_oss/versions/oss000000002_rename_tool_connections_to_gateway_connections.py create mode 100644 api/oss/src/core/connections/__init__.py create mode 100644 api/oss/src/core/connections/dtos.py create mode 100644 api/oss/src/core/connections/exceptions.py create mode 100644 api/oss/src/core/connections/interfaces.py create mode 100644 api/oss/src/core/connections/providers/__init__.py create mode 100644 api/oss/src/core/connections/providers/composio/__init__.py create mode 100644 api/oss/src/core/connections/providers/composio/adapter.py create mode 100644 api/oss/src/core/connections/registry.py create mode 100644 api/oss/src/core/connections/service.py rename api/oss/src/core/{tools => connections}/utils.py (96%) create mode 100644 api/oss/src/dbs/postgres/connections/__init__.py create mode 100644 api/oss/src/dbs/postgres/connections/dao.py create mode 100644 api/oss/src/dbs/postgres/connections/dbes.py rename api/oss/src/dbs/postgres/{tools => connections}/mappings.py (80%) create mode 100644 api/oss/tests/pytest/acceptance/tools/test_tools_connections.py diff --git a/api/oss/src/dbs/postgres/tools/__init__.py b/api/ee/tests/pytest/acceptance/tools/__init__.py similarity index 100% rename from api/oss/src/dbs/postgres/tools/__init__.py rename to api/ee/tests/pytest/acceptance/tools/__init__.py diff --git a/api/ee/tests/pytest/acceptance/tools/test_tools_connections.py b/api/ee/tests/pytest/acceptance/tools/test_tools_connections.py new file mode 100644 index 0000000000..a3e93c0361 --- /dev/null +++ b/api/ee/tests/pytest/acceptance/tools/test_tools_connections.py @@ -0,0 +1,154 @@ +"""EE acceptance tests for the /tools/connections contract (WP0). + +Mirrors the OSS suite (oss/tests/pytest/acceptance/tools/test_tools_connections.py) +but exercises /tools/connections as a business-plan, developer-role account. +Under EE the endpoints are gated on the tools permission surface (VIEW_TOOLS for +reads, EDIT_TOOLS for writes); a developer role carries both, so this verifies +the contract behaves once the gate is satisfied. + +The query endpoint is DB-only and needs no Composio credentials — it also proves +the gateway_connections rename landed in EE. Create / revoke make real provider +calls, so those are gated on COMPOSIO_API_KEY. + +Requires a running API. +""" + +import os +from uuid import uuid4 + +import pytest +import requests + +from utils.constants import BASE_TIMEOUT + + +_COMPOSIO_ENABLED = bool(os.getenv("COMPOSIO_API_KEY")) +_requires_composio = pytest.mark.skipif( + not _COMPOSIO_ENABLED, + reason="needs live Composio credentials (COMPOSIO_API_KEY)", +) + + +def _create_developer_business_account(admin_api): + uid = uuid4().hex[:12] + email = f"connections-dev-{uid}@test.agenta.ai" + resp = admin_api( + "POST", + "/admin/simple/accounts/", + json={ + "accounts": { + "u": { + "user": {"email": email}, + "options": { + "create_api_keys": True, + "return_api_keys": True, + "seed_defaults": False, + }, + "subscription": {"plan": "cloud_v0_business"}, + "organization_memberships": [ + { + "organization_ref": {"ref": "org"}, + "user_ref": {"ref": "user"}, + "role": "developer", + } + ], + "workspace_memberships": [ + { + "workspace_ref": {"ref": "wrk"}, + "user_ref": {"ref": "user"}, + "role": "developer", + } + ], + "project_memberships": [ + { + "project_ref": {"ref": "prj"}, + "user_ref": {"ref": "user"}, + "role": "developer", + } + ], + } + } + }, + ) + assert resp.status_code == 200, resp.text + account = resp.json()["accounts"]["u"] + return { + "email": email, + "credentials": f"ApiKey {account['api_keys']['key']}", + } + + +def _delete_account_by_email(admin_api, *, email): + resp = admin_api( + "DELETE", + "/admin/simple/accounts/", + json={"accounts": {"u": {"user": {"email": email}}}, "confirm": "delete"}, + ) + assert resp.status_code == 204, resp.text + + +@pytest.fixture(scope="class") +def connections_api(admin_api, ag_env): + account = _create_developer_business_account(admin_api) + + def _request(method: str, endpoint: str, **kwargs): + headers = kwargs.pop("headers", {}) + headers.setdefault("Authorization", account["credentials"]) + return requests.request( + method=method, + url=f"{ag_env['api_url']}{endpoint}", + headers=headers, + timeout=BASE_TIMEOUT, + **kwargs, + ) + + yield _request + + _delete_account_by_email(admin_api, email=account["email"]) + + +class TestToolsConnectionsQuery: + def test_query_connections_returns_200(self, connections_api): + response = connections_api("POST", "/tools/connections/query") + assert response.status_code == 200 + + def test_query_connections_response_shape(self, connections_api): + body = connections_api("POST", "/tools/connections/query").json() + assert "count" in body + assert "connections" in body + assert isinstance(body["connections"], list) + assert body["count"] == len(body["connections"]) + + +class TestToolsConnectionsGet: + def test_get_unknown_connection_returns_404(self, connections_api): + response = connections_api("GET", f"/tools/connections/{uuid4()}") + assert response.status_code == 404 + + +@_requires_composio +class TestToolsConnectionsLifecycle: + def test_create_revoke_roundtrip(self, connections_api): + slug = f"acc-{uuid4().hex[:8]}" + create = connections_api( + "POST", + "/tools/connections/", + json={ + "connection": { + "slug": slug, + "provider_key": "composio", + "integration_key": "github", + "data": {"auth_scheme": "oauth"}, + } + }, + ) + assert create.status_code == 200, create.text + connection_id = create.json()["connection"]["id"] + + # Local-only revoke (C7/B3): flips is_valid on the shared row, no + # provider call, no cascade. + revoke = connections_api("POST", f"/tools/connections/{connection_id}/revoke") + assert revoke.status_code == 200, revoke.text + assert revoke.json()["connection"]["flags"]["is_valid"] is False + + connections_api("DELETE", f"/tools/connections/{connection_id}") diff --git a/api/entrypoints/routers.py b/api/entrypoints/routers.py index d90b38c5f1..c235e817fd 100644 --- a/api/entrypoints/routers.py +++ b/api/entrypoints/routers.py @@ -134,7 +134,10 @@ from oss.src.core.accounts.service import PlatformAdminAccountsService from oss.src.apis.fastapi.accounts.router import PlatformAdminAccountsRouter -from oss.src.dbs.postgres.tools.dao import ToolsDAO +from oss.src.dbs.postgres.connections.dao import ConnectionsDAO +from oss.src.core.connections.providers.composio import ComposioConnectionsAdapter +from oss.src.core.connections.registry import ConnectionsGatewayRegistry +from oss.src.core.connections.service import ConnectionsService from oss.src.core.tools.providers.composio import ComposioToolsAdapter from oss.src.core.tools.registry import ToolsGatewayRegistry from oss.src.core.tools.service import ToolsService @@ -209,6 +212,9 @@ async def lifespan(*args, **kwargs): for adapter in _composio_adapters.values(): await adapter.close() + for adapter in _composio_connections_adapters.values(): + await adapter.close() + await _transactions_engine.close() await _analytics_engine.close() await _streams_engine.close() @@ -439,7 +445,7 @@ async def lifespan(*args, **kwargs): evaluations_dao = EvaluationsDAO(engine=_transactions_engine) folders_dao = FoldersDAO(engine=_transactions_engine) -tools_dao = ToolsDAO(engine=_transactions_engine) +connections_dao = ConnectionsDAO(engine=_transactions_engine) # SERVICES --------------------------------------------------------------------- @@ -574,6 +580,23 @@ async def lifespan(*args, **kwargs): simple_evaluations_service=simple_evaluations_service, ) +# Connections adapter + service (owns gateway_connections; consumed by tools) +_composio_connections_adapters = {} +if env.composio.enabled: + _composio_connections_adapters["composio"] = ComposioConnectionsAdapter( + api_key=env.composio.api_key, # type: ignore[arg-type] # guarded by .enabled + api_url=env.composio.api_url, + ) + +connections_adapter_registry = ConnectionsGatewayRegistry( + adapters=_composio_connections_adapters, +) + +connections_service = ConnectionsService( + connections_dao=connections_dao, + adapter_registry=connections_adapter_registry, +) + # Tools adapter + service _composio_adapters = {} if env.composio.enabled: @@ -589,7 +612,7 @@ async def lifespan(*args, **kwargs): ) tools_service = ToolsService( - tools_dao=tools_dao, + connections_service=connections_service, adapter_registry=tools_adapter_registry, ) diff --git a/api/oss/databases/postgres/migrations/core_oss/versions/oss000000002_rename_tool_connections_to_gateway_connections.py b/api/oss/databases/postgres/migrations/core_oss/versions/oss000000002_rename_tool_connections_to_gateway_connections.py new file mode 100644 index 0000000000..0eca1077c6 --- /dev/null +++ b/api/oss/databases/postgres/migrations/core_oss/versions/oss000000002_rename_tool_connections_to_gateway_connections.py @@ -0,0 +1,49 @@ +"""rename tool_connections to gateway_connections + +Connection ownership moves out of /tools into the shared, routerless +connections domain (gateway-triggers WP0). Rename-only — no data transform. +Authored once in the shared core_oss chain so it runs in BOTH editions; the +legacy chain that created tool_connections is parked. + +Revision ID: oss000000002 +Revises: oss000000001 +Create Date: 2026-06-18 00:00:00.000000 + +""" + +from typing import Sequence, Union + +from alembic import op + + +# revision identifiers, used by Alembic. +revision: str = "oss000000002" +down_revision: Union[str, None] = "oss000000001" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.rename_table("tool_connections", "gateway_connections") + op.execute( + "ALTER TABLE gateway_connections " + "RENAME CONSTRAINT uq_tool_connections_project_provider_integration_slug " + "TO uq_gateway_connections_project_provider_integration_slug" + ) + op.execute( + "ALTER INDEX ix_tool_connections_project_provider_integration " + "RENAME TO ix_gateway_connections_project_provider_integration" + ) + + +def downgrade() -> None: + op.execute( + "ALTER INDEX ix_gateway_connections_project_provider_integration " + "RENAME TO ix_tool_connections_project_provider_integration" + ) + op.execute( + "ALTER TABLE gateway_connections " + "RENAME CONSTRAINT uq_gateway_connections_project_provider_integration_slug " + "TO uq_tool_connections_project_provider_integration_slug" + ) + op.rename_table("gateway_connections", "tool_connections") diff --git a/api/oss/src/apis/fastapi/tools/models.py b/api/oss/src/apis/fastapi/tools/models.py index 891b276c22..49aa25070a 100644 --- a/api/oss/src/apis/fastapi/tools/models.py +++ b/api/oss/src/apis/fastapi/tools/models.py @@ -2,6 +2,10 @@ from pydantic import BaseModel +from oss.src.core.connections.dtos import ( + Connection, + ConnectionCreate, +) from oss.src.core.tools.dtos import ( # Tool Catalog ToolCatalogAction, @@ -10,9 +14,6 @@ ToolCatalogIntegrationDetails, ToolCatalogProvider, ToolCatalogProviderDetails, - # Tool Connections - ToolConnection, - ToolConnectionCreate, # Tool Calls ToolResult, ) @@ -67,17 +68,17 @@ class ToolCatalogActionsResponse(BaseModel): class ToolConnectionCreateRequest(BaseModel): - connection: ToolConnectionCreate + connection: ConnectionCreate class ToolConnectionResponse(BaseModel): count: int = 0 - connection: Optional[ToolConnection] = None + connection: Optional[Connection] = None class ToolConnectionsResponse(BaseModel): count: int = 0 - connections: List[ToolConnection] = [] + connections: List[Connection] = [] # --------------------------------------------------------------------------- diff --git a/api/oss/src/apis/fastapi/tools/router.py b/api/oss/src/apis/fastapi/tools/router.py index 043d114fa7..b59cd36417 100644 --- a/api/oss/src/apis/fastapi/tools/router.py +++ b/api/oss/src/apis/fastapi/tools/router.py @@ -50,7 +50,7 @@ from oss.src.core.tools.service import ( ToolsService, ) -from oss.src.core.tools.utils import decode_oauth_state +from oss.src.core.connections.utils import decode_oauth_state from oss.src.utils.env import env _SLUG_SEGMENT_RE = re.compile(r"^[a-zA-Z0-9_-]+$") diff --git a/api/oss/src/core/connections/__init__.py b/api/oss/src/core/connections/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/oss/src/core/connections/dtos.py b/api/oss/src/core/connections/dtos.py new file mode 100644 index 0000000000..7655536508 --- /dev/null +++ b/api/oss/src/core/connections/dtos.py @@ -0,0 +1,130 @@ +from enum import Enum +from typing import Any, Dict, Optional + +from pydantic import BaseModel + +from oss.src.core.shared.dtos import ( + Header, + Identifier, + Lifecycle, + Metadata, + Slug, + Json, +) + +# --------------------------------------------------------------------------- +# Connection Enums +# --------------------------------------------------------------------------- + + +class ConnectionProviderKind(str, Enum): + COMPOSIO = "composio" + AGENTA = "agenta" + + +class ConnectionAuthScheme(str, Enum): + OAUTH = "oauth" + API_KEY = "api_key" + + +# --------------------------------------------------------------------------- +# Connections (domain DTOs) +# --------------------------------------------------------------------------- + + +class ConnectionStatus(BaseModel): + redirect_url: Optional[str] = None + + +class ConnectionCreateData(BaseModel): + callback_url: Optional[str] = None + # + auth_scheme: Optional[ConnectionAuthScheme] = None + + +class Connection( + Identifier, + Slug, + Header, + Lifecycle, + Metadata, +): + provider_key: ConnectionProviderKind + integration_key: str + # + data: Optional[Json] = None + # + status: Optional[ConnectionStatus] = None + + @property + def provider_connection_id(self) -> Optional[str]: + """Get provider-specific connection ID from data.""" + if self.data and isinstance(self.data, dict): + # For Composio, it's stored as "connected_account_id" + return self.data.get("connected_account_id") or self.data.get( + "provider_connection_id" + ) + return None + + @property + def is_active(self) -> bool: + """Check if connection is active (not deleted).""" + if self.flags and isinstance(self.flags, dict): + return self.flags.get("is_active", False) + return False + + @property + def is_valid(self) -> bool: + """Check if connection is valid (authenticated).""" + if self.flags and isinstance(self.flags, dict): + return self.flags.get("is_valid", False) + return False + + +class ConnectionCreate( + Slug, + Header, + Metadata, +): + provider_key: ConnectionProviderKind + integration_key: str + # + data: Optional[ConnectionCreateData] = None + + +class Usage(BaseModel): + """Cross-domain usage of a connection (C7). + + Reports how many consumers reference a given connection. ``tools`` is True + when the connection backs the tools domain; ``subscriptions`` counts trigger + subscriptions that read the same shared row. + """ + + tools: bool = False + subscriptions: int = 0 + + +# --------------------------------------------------------------------------- +# Connection (adapter-level DTOs) +# --------------------------------------------------------------------------- + + +class ConnectionRequest(BaseModel): + """Input DTO for initiating a provider connection via a gateway adapter.""" + + user_id: str + integration_key: str + auth_scheme: Optional[str] = None + callback_url: Optional[str] = None + + +class ConnectionResponse(BaseModel): + """Output DTO from ConnectionsGatewayInterface.initiate_connection. + + The adapter builds ``connection_data`` with provider-specific fields so the + service never needs to know which provider it is talking to. + """ + + provider_connection_id: str + redirect_url: Optional[str] = None + connection_data: Dict[str, Any] = {} diff --git a/api/oss/src/core/connections/exceptions.py b/api/oss/src/core/connections/exceptions.py new file mode 100644 index 0000000000..5be6636a72 --- /dev/null +++ b/api/oss/src/core/connections/exceptions.py @@ -0,0 +1,65 @@ +from typing import Optional + + +class ConnectionsError(Exception): + """Base exception for the connections domain.""" + + def __init__(self, message: str = "Connections error"): + self.message = message + super().__init__(self.message) + + +class ProviderNotFoundError(ConnectionsError): + """Raised when the requested provider_key has no registered adapter.""" + + def __init__(self, provider_key: str): + self.provider_key = provider_key + super().__init__(f"Provider not found: {provider_key}") + + +class ConnectionNotFoundError(ConnectionsError): + """Raised when a connection cannot be found.""" + + def __init__( + self, + *, + connection_id: Optional[str] = None, + ): + self.connection_id = connection_id + super().__init__(f"Connection not found: {connection_id}") + + +class ConnectionInactiveError(ConnectionsError): + """Raised when trying to use an inactive or revoked connection.""" + + def __init__( + self, + *, + connection_id: str, + detail: Optional[str] = None, + ): + self.connection_id = connection_id + self.detail = detail + msg = f"Connection is inactive or revoked: {connection_id}" + if detail: + msg += f" - {detail}" + super().__init__(msg) + + +class AdapterError(ConnectionsError): + """Raised when an adapter operation fails.""" + + def __init__( + self, + *, + provider_key: str, + operation: str, + detail: Optional[str] = None, + ): + self.provider_key = provider_key + self.operation = operation + self.detail = detail + msg = f"Adapter error ({provider_key}.{operation})" + if detail: + msg += f": {detail}" + super().__init__(msg) diff --git a/api/oss/src/core/connections/interfaces.py b/api/oss/src/core/connections/interfaces.py new file mode 100644 index 0000000000..9d32d295fc --- /dev/null +++ b/api/oss/src/core/connections/interfaces.py @@ -0,0 +1,127 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional +from uuid import UUID + +from oss.src.core.connections.dtos import ( + Connection, + ConnectionCreate, + ConnectionRequest, + ConnectionResponse, +) + + +class ConnectionsDAOInterface(ABC): + """Connection persistence contract — owns the gateway_connections table.""" + + @abstractmethod + async def create_connection( + self, + *, + project_id: UUID, + user_id: UUID, + # + connection_create: ConnectionCreate, + ) -> Optional[Connection]: ... + + @abstractmethod + async def get_connection( + self, + *, + project_id: UUID, + connection_id: UUID, + ) -> Optional[Connection]: ... + + @abstractmethod + async def update_connection( + self, + *, + project_id: UUID, + connection_id: UUID, + # + is_valid: Optional[bool] = None, + is_active: Optional[bool] = None, + provider_connection_id: Optional[str] = None, + data_update: Optional[Dict[str, Any]] = None, + ) -> Optional[Connection]: ... + + @abstractmethod + async def delete_connection( + self, + *, + project_id: UUID, + connection_id: UUID, + ) -> bool: ... + + @abstractmethod + async def query_connections( + self, + *, + project_id: UUID, + # + provider_key: Optional[str] = None, + integration_key: Optional[str] = None, + is_active: Optional[bool] = True, + ) -> List[Connection]: ... + + @abstractmethod + async def find_connection_by_provider_id( + self, + *, + provider_connection_id: str, + ) -> Optional[Connection]: ... + + @abstractmethod + async def activate_connection_by_provider_id( + self, + *, + provider_connection_id: str, + project_id: Optional[UUID] = None, + ) -> Optional[Connection]: ... + + +class ConnectionsGatewayInterface(ABC): + """Adapter port for external connection providers (Composio, Agenta, etc.). + + Provider-keyed on ``provider_connection_id`` and returns provider data. + Holds only the auth verbs; tool-specific verbs (execute, catalog) stay on + ``ToolsGatewayInterface``. + """ + + @abstractmethod + async def initiate_connection( + self, + *, + request: ConnectionRequest, + ) -> ConnectionResponse: + """Initiate a provider-side connection. Returns a typed response with + provider_connection_id, redirect_url, and connection_data — the dict + the service will persist in the local connection record. + """ + ... + + @abstractmethod + async def get_connection_status( + self, + *, + provider_connection_id: str, + ) -> Dict[str, Any]: + """Poll provider for updated connection status.""" + ... + + @abstractmethod + async def refresh_connection( + self, + *, + provider_connection_id: str, + force: bool = False, + callback_url: Optional[str] = None, + integration_key: Optional[str] = None, + user_id: Optional[str] = None, + ) -> Dict[str, Any]: ... + + @abstractmethod + async def revoke_connection( + self, + *, + provider_connection_id: str, + ) -> bool: ... diff --git a/api/oss/src/core/connections/providers/__init__.py b/api/oss/src/core/connections/providers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/oss/src/core/connections/providers/composio/__init__.py b/api/oss/src/core/connections/providers/composio/__init__.py new file mode 100644 index 0000000000..57aad3220b --- /dev/null +++ b/api/oss/src/core/connections/providers/composio/__init__.py @@ -0,0 +1,20 @@ +# Avoid importing adapter here to prevent SDK dependency issues in standalone scripts. +# Import directly when needed: +# from oss.src.core.connections.providers.composio.adapter import ( +# ComposioConnectionsAdapter, +# ) + +__all__ = [ + "ComposioConnectionsAdapter", +] + + +def __getattr__(name): + """Lazy import to avoid SDK dependency on module import.""" + if name == "ComposioConnectionsAdapter": + from oss.src.core.connections.providers.composio.adapter import ( + ComposioConnectionsAdapter, + ) + + return ComposioConnectionsAdapter + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/api/oss/src/core/connections/providers/composio/adapter.py b/api/oss/src/core/connections/providers/composio/adapter.py new file mode 100644 index 0000000000..dd896c31eb --- /dev/null +++ b/api/oss/src/core/connections/providers/composio/adapter.py @@ -0,0 +1,302 @@ +from typing import Any, Dict, Optional + +import httpx + +from oss.src.utils.logging import get_module_logger + +from oss.src.core.connections.dtos import ( + ConnectionRequest, + ConnectionResponse, +) +from oss.src.core.connections.interfaces import ConnectionsGatewayInterface +from oss.src.core.connections.exceptions import AdapterError + + +log = get_module_logger(__name__) + +COMPOSIO_DEFAULT_API_URL = "https://backend.composio.dev/api/v3" + + +class ComposioConnectionsAdapter(ConnectionsGatewayInterface): + """Composio V3 connection auth adapter — uses httpx directly (no SDK). + + Holds the four auth verbs (initiate / status / refresh / revoke) behind + ``ConnectionsGatewayInterface``. Catalog and tool execution stay on the + tools adapter. + """ + + def __init__( + self, + *, + api_key: str, + api_url: str = COMPOSIO_DEFAULT_API_URL, + ): + self.api_key = api_key + self.api_url = api_url.rstrip("/") + # Shared client — one connection pool for the adapter's lifetime. + # Call close() on shutdown (wired in entrypoints/routers.py lifespan). + self._client = httpx.AsyncClient(timeout=30.0) + + async def close(self) -> None: + """Close the shared HTTP client and release connection pool resources.""" + await self._client.aclose() + + def _headers(self) -> Dict[str, str]: + return { + "x-api-key": self.api_key, + "Content-Type": "application/json", + } + + async def _get( + self, + path: str, + *, + params: Optional[Dict[str, Any]] = None, + ) -> Any: + resp = await self._client.get( + f"{self.api_url}{path}", + headers=self._headers(), + params=params, + ) + resp.raise_for_status() + return resp.json() + + async def _post( + self, + path: str, + *, + json: Optional[Dict[str, Any]] = None, + ) -> Any: + resp = await self._client.post( + f"{self.api_url}{path}", + headers=self._headers(), + json=json or {}, + ) + if not resp.is_success: + log.error( + "Composio POST %s → %s: %s", + path, + resp.status_code, + resp.text, + ) + resp.raise_for_status() + return resp.json() + + async def _delete(self, path: str) -> bool: + resp = await self._client.delete( + f"{self.api_url}{path}", + headers=self._headers(), + ) + resp.raise_for_status() + return True + + # ----------------------------------------------------------------------- + # Connections + # ----------------------------------------------------------------------- + + async def initiate_connection( + self, + *, + request: ConnectionRequest, + ) -> ConnectionResponse: + user_id = request.user_id + integration_key = request.integration_key + auth_scheme = request.auth_scheme + callback_url = request.callback_url + + # Step 1: validate the toolkit exists and get its auth scheme info. + try: + toolkit = await self._get(f"/toolkits/{integration_key}") + except httpx.HTTPStatusError as e: + if e.response.status_code == 404: + raise AdapterError( + provider_key="composio", + operation="initiate_connection.validate_toolkit", + detail=f"Integration '{integration_key}' not found", + ) from e + raise AdapterError( + provider_key="composio", + operation="initiate_connection.validate_toolkit", + detail=str(e), + ) from e + except httpx.HTTPError as e: + raise AdapterError( + provider_key="composio", + operation="initiate_connection.validate_toolkit", + detail=str(e), + ) from e + + # Step 2: create an auth config for this integration. + # api_key → use_custom_auth; Composio's redirect UI collects the credentials. + # oauth / None → use_composio_managed_auth. + log.info( + "initiate_connection: integration_key=%s auth_scheme=%r", + integration_key, + auth_scheme, + ) + + if auth_scheme == "api_key": + # Derive Composio authScheme from toolkit's auth_config_details. + # Fall back to "API_KEY" as the common default. + composio_auth_scheme = "API_KEY" + for detail in toolkit.get("auth_config_details") or []: + mode = detail.get("mode", "") + if mode and "oauth" not in mode.lower(): + composio_auth_scheme = mode + break + + auth_config_body: Dict[str, Any] = { + "type": "use_custom_auth", + "authScheme": composio_auth_scheme, + } + else: + auth_config_body = {"type": "use_composio_managed_auth"} + + auth_configs_payload = { + "toolkit": {"slug": integration_key}, + "auth_config": auth_config_body, + } + log.info( + "initiate_connection: POST /auth_configs payload=%s", auth_configs_payload + ) + + try: + auth_config_result = await self._post( + "/auth_configs", + json=auth_configs_payload, + ) + except httpx.HTTPError as e: + raise AdapterError( + provider_key="composio", + operation="initiate_connection.create_auth_config", + detail=str(e), + ) from e + + auth_config_id = (auth_config_result.get("auth_config") or {}).get("id") + if not auth_config_id: + raise AdapterError( + provider_key="composio", + operation="initiate_connection.create_auth_config", + detail=f"No auth_config_id in response for integration '{integration_key}'", + ) + + log.info( + "initiate_connection: integration_key=%s auth_config_id=%s", + integration_key, + auth_config_id, + ) + + # Step 3: initiate connected account link. + payload: Dict[str, Any] = { + "user_id": user_id, + "auth_config_id": auth_config_id, + } + if callback_url: + payload["callback_url"] = callback_url + + try: + result = await self._post("/connected_accounts/link", json=payload) + except httpx.HTTPError as e: + raise AdapterError( + provider_key="composio", + operation="initiate_connection", + detail=str(e), + ) from e + + provider_connection_id = result.get("connected_account_id", "") + redirect_url = result.get("redirect_url") + + connection_data: Dict[str, Any] = { + "connected_account_id": provider_connection_id, + "auth_config_id": auth_config_id, + } + if redirect_url: + connection_data["redirect_url"] = redirect_url + + return ConnectionResponse( + provider_connection_id=provider_connection_id, + redirect_url=redirect_url, + connection_data=connection_data, + ) + + async def get_connection_status( + self, + *, + provider_connection_id: str, + ) -> Dict[str, Any]: + try: + result = await self._get(f"/connected_accounts/{provider_connection_id}") + except httpx.HTTPError as e: + raise AdapterError( + provider_key="composio", + operation="get_connection_status", + detail=str(e), + ) from e + + return { + "status": result.get("status"), + "is_valid": result.get("status") == "ACTIVE", + } + + async def refresh_connection( + self, + *, + provider_connection_id: str, + force: bool = False, + callback_url: Optional[str] = None, + integration_key: Optional[str] = None, + user_id: Optional[str] = None, + ) -> Dict[str, Any]: + # For Composio OAuth flows, "refresh" means re-initiating the auth link. + # The provider does not expose a token-refresh endpoint for OAuth connections, + # so we create a new connected_accounts/link which the user must re-authorize. + if integration_key and user_id: + result = await self.initiate_connection( + request=ConnectionRequest( + user_id=user_id, + integration_key=integration_key, + callback_url=callback_url, + ), + ) + return { + "id": result.provider_connection_id, + "redirect_url": result.redirect_url, + "auth_config_id": result.connection_data.get("auth_config_id"), + "is_valid": False, # Re-auth pending until callback fires + } + + payload: Dict[str, Any] = {} + if callback_url: + payload["callback_url"] = callback_url + + try: + result = await self._post( + f"/connected_accounts/{provider_connection_id}/refresh", + json=payload, + ) + except httpx.HTTPError as e: + raise AdapterError( + provider_key="composio", + operation="refresh_connection", + detail=str(e), + ) from e + + return { + "status": result.get("status"), + "is_valid": result.get("status") == "ACTIVE", + "redirect_url": result.get("redirect_url"), + } + + async def revoke_connection( + self, + *, + provider_connection_id: str, + ) -> bool: + try: + return await self._delete(f"/connected_accounts/{provider_connection_id}") + except httpx.HTTPError as e: + raise AdapterError( + provider_key="composio", + operation="revoke_connection", + detail=str(e), + ) from e diff --git a/api/oss/src/core/connections/registry.py b/api/oss/src/core/connections/registry.py new file mode 100644 index 0000000000..347895dd45 --- /dev/null +++ b/api/oss/src/core/connections/registry.py @@ -0,0 +1,27 @@ +from typing import Dict, ItemsView + +from oss.src.core.connections.interfaces import ConnectionsGatewayInterface +from oss.src.core.connections.exceptions import ProviderNotFoundError + + +class ConnectionsGatewayRegistry: + """Dispatches to the correct connection adapter based on provider_key.""" + + def __init__( + self, + *, + adapters: Dict[str, ConnectionsGatewayInterface], + ): + self._adapters = adapters + + def get(self, provider_key: str) -> ConnectionsGatewayInterface: + adapter = self._adapters.get(provider_key) + if not adapter: + raise ProviderNotFoundError(provider_key) + return adapter + + def keys(self) -> list[str]: + return list(self._adapters.keys()) + + def items(self) -> ItemsView[str, ConnectionsGatewayInterface]: + return self._adapters.items() diff --git a/api/oss/src/core/connections/service.py b/api/oss/src/core/connections/service.py new file mode 100644 index 0000000000..026f7d3045 --- /dev/null +++ b/api/oss/src/core/connections/service.py @@ -0,0 +1,327 @@ +from typing import Any, Dict, List, Optional +from uuid import UUID + +from oss.src.utils.logging import get_module_logger +from oss.src.utils.env import env + +from oss.src.core.connections.dtos import ( + Connection, + ConnectionCreate, + ConnectionRequest, + Usage, +) +from oss.src.core.connections.interfaces import ConnectionsDAOInterface +from oss.src.core.connections.registry import ConnectionsGatewayRegistry +from oss.src.core.connections.exceptions import ( + ConnectionInactiveError, + ConnectionNotFoundError, +) +from oss.src.core.connections.utils import make_oauth_state + + +log = get_module_logger(__name__) + +# The OAuth callback stays on the /tools router so the public contract is +# unchanged even though the connection now lives in its own domain. +_CALLBACK_PATH = "/tools/connections/callback" + + +class ConnectionsService: + """Project-scoped service that owns gateway_connections. + + Returns domain ``Connection`` DTOs. Downstream domains (tools, triggers) + consume this service; it never imports from them. + """ + + def __init__( + self, + *, + connections_dao: ConnectionsDAOInterface, + adapter_registry: ConnectionsGatewayRegistry, + ): + self.connections_dao = connections_dao + self.adapter_registry = adapter_registry + + # ----------------------------------------------------------------------- + # Reads + # ----------------------------------------------------------------------- + + async def query_connections( + self, + *, + project_id: UUID, + # + provider_key: Optional[str] = None, + integration_key: Optional[str] = None, + is_active: Optional[bool] = True, + ) -> List[Connection]: + """Query connections with optional filtering. Defaults to active-only.""" + return await self.connections_dao.query_connections( + project_id=project_id, + provider_key=provider_key, + integration_key=integration_key, + is_active=is_active, + ) + + async def list_connections( + self, + *, + project_id: UUID, + provider_key: str, + integration_key: str, + ) -> List[Connection]: + """List connections for a specific integration (catalog enrichment).""" + return await self.connections_dao.query_connections( + project_id=project_id, + provider_key=provider_key, + integration_key=integration_key, + ) + + async def get_connection( + self, + *, + project_id: UUID, + connection_id: UUID, + ) -> Optional[Connection]: + """Return a single connection by ID scoped to the project, or None.""" + # Read-only by design: do not mutate local state during GET. + return await self.connections_dao.get_connection( + project_id=project_id, + connection_id=connection_id, + ) + + async def find_connection_by_provider_connection_id( + self, + *, + provider_connection_id: str, + ) -> Optional[Connection]: + """Find any connection by its provider-side ID (for OAuth callbacks).""" + return await self.connections_dao.find_connection_by_provider_id( + provider_connection_id=provider_connection_id, + ) + + async def activate_connection_by_provider_connection_id( + self, + *, + provider_connection_id: str, + project_id: Optional[UUID] = None, + ) -> Optional[Connection]: + """Mark a connection valid+active after OAuth completes.""" + return await self.connections_dao.activate_connection_by_provider_id( + provider_connection_id=provider_connection_id, + project_id=project_id, + ) + + async def usage( + self, + *, + project_id: UUID, + connection_id: UUID, + ) -> Usage: + """Report cross-domain usage of a connection (C7). + + The seam for "used by tools / N subs". Tools and triggers read the same + shared row, so this is a read-only count of consumers. Subscriptions are + not yet a consumer in this WP, so the count is the seam (0). + """ + conn = await self.connections_dao.get_connection( + project_id=project_id, + connection_id=connection_id, + ) + if not conn: + raise ConnectionNotFoundError(connection_id=str(connection_id)) + + return Usage( + tools=True, + subscriptions=0, + ) + + # ----------------------------------------------------------------------- + # Writes + # ----------------------------------------------------------------------- + + async def initiate_connection( + self, + *, + project_id: UUID, + user_id: UUID, + # + connection_create: ConnectionCreate, + ) -> Connection: + """Initiate a provider connection and persist it locally in pending state.""" + provider_key = connection_create.provider_key.value + integration_key = connection_create.integration_key + + adapter = self.adapter_registry.get(provider_key) + + # Callback URL is server-owned. Do not trust/require client-provided values. + # Embed a signed state token so the callback can scope the activation. + state = make_oauth_state( + project_id=project_id, + user_id=user_id, + secret_key=env.agenta.crypt_key, + ) + callback_url = f"{env.agenta.api_url}{_CALLBACK_PATH}?state={state}" + + # Initiate with provider + connection_create_data = connection_create.data + provider_result = await adapter.initiate_connection( + request=ConnectionRequest( + user_id=str(project_id), + integration_key=integration_key, + auth_scheme=connection_create_data.auth_scheme.value + if connection_create_data and connection_create_data.auth_scheme + else None, + callback_url=callback_url, + ), + ) + + # Merge provider-returned connection_data with service-level project_id. + # The adapter owns provider-specific field names; the service adds project scope. + data: Dict[str, Any] = dict(provider_result.connection_data) + data["project_id"] = str(project_id) + connection_create.data = data # type: ignore[assignment] + + # Persist locally + return await self.connections_dao.create_connection( + project_id=project_id, + user_id=user_id, + # + connection_create=connection_create, + ) + + async def delete_connection( + self, + *, + project_id: UUID, + connection_id: UUID, + ) -> bool: + """Revoke provider-side connection and delete locally. Raises ConnectionNotFoundError if missing.""" + conn = await self.connections_dao.get_connection( + project_id=project_id, + connection_id=connection_id, + ) + + if not conn: + raise ConnectionNotFoundError( + connection_id=str(connection_id), + ) + + # Revoke provider-side + if conn.provider_connection_id: + adapter = self.adapter_registry.get(conn.provider_key.value) + try: + await adapter.revoke_connection( + provider_connection_id=conn.provider_connection_id, + ) + except Exception: + log.warning( + "Failed to revoke provider connection %s, proceeding with local delete", + conn.provider_connection_id, + ) + + # Delete locally + return await self.connections_dao.delete_connection( + project_id=project_id, + connection_id=connection_id, + ) + + async def revoke_connection( + self, + *, + project_id: UUID, + connection_id: UUID, + ) -> Connection: + """Mark a connection invalid locally without touching the provider. + + Local-only by design (C7/B3): flipping ``is_valid=False`` on the shared + gateway_connections row is the cross-domain effect — tools and triggers + read the same row, so everyone sees the revocation without a provider + call or cascade. + """ + conn = await self.connections_dao.get_connection( + project_id=project_id, + connection_id=connection_id, + ) + + if not conn: + raise ConnectionNotFoundError( + connection_id=str(connection_id), + ) + + updated = await self.connections_dao.update_connection( + project_id=project_id, + connection_id=connection_id, + is_valid=False, + ) + + return updated or conn + + async def refresh_connection( + self, + *, + project_id: UUID, + connection_id: UUID, + # + force: bool = False, + ) -> Connection: + conn = await self.connections_dao.get_connection( + project_id=project_id, + connection_id=connection_id, + ) + + if not conn: + raise ConnectionNotFoundError( + connection_id=str(connection_id), + ) + + if not conn.provider_connection_id: + raise ConnectionNotFoundError( + connection_id=str(connection_id), + ) + + if not conn.is_active: + raise ConnectionInactiveError( + connection_id=str(connection_id), + detail="Cannot refresh an inactive connection. Create a new connection to re-establish authorization.", + ) + + # Callback URL is server-owned with a signed state token. + state = make_oauth_state( + project_id=project_id, + user_id=project_id, # refresh has no user_id; use project_id as entity + secret_key=env.agenta.crypt_key, + ) + callback_url = f"{env.agenta.api_url}{_CALLBACK_PATH}?state={state}" + + adapter = self.adapter_registry.get(conn.provider_key.value) + + # Delegate provider-specific refresh logic to the adapter. + # For OAuth providers (e.g. Composio), the adapter re-initiates the link. + provider_connection_id = conn.provider_connection_id + result = await adapter.refresh_connection( + provider_connection_id=conn.provider_connection_id, + force=force, + callback_url=callback_url, + integration_key=conn.integration_key, + user_id=str(project_id), + ) + provider_connection_id = result.get("id") or provider_connection_id + auth_config_id = result.get("auth_config_id") + is_valid = result.get("is_valid", conn.is_valid) + + redirect_url = result.get("redirect_url") + # Always overwrite redirect_url so FE doesn't reuse stale links from prior flows. + data_update = {"redirect_url": redirect_url} + if auth_config_id: + data_update["auth_config_id"] = auth_config_id + + updated = await self.connections_dao.update_connection( + project_id=project_id, + connection_id=connection_id, + is_valid=is_valid, + provider_connection_id=provider_connection_id, + data_update=data_update, + ) + + return updated or conn diff --git a/api/oss/src/core/tools/utils.py b/api/oss/src/core/connections/utils.py similarity index 96% rename from api/oss/src/core/tools/utils.py rename to api/oss/src/core/connections/utils.py index 79334acd55..58a3dd18b5 100644 --- a/api/oss/src/core/tools/utils.py +++ b/api/oss/src/core/connections/utils.py @@ -1,4 +1,4 @@ -"""OAuth state signing utilities for tool connection callbacks.""" +"""OAuth state signing utilities for connection callbacks.""" import base64 import hashlib diff --git a/api/oss/src/core/tools/dtos.py b/api/oss/src/core/tools/dtos.py index a588965f61..2c1ac2bf82 100644 --- a/api/oss/src/core/tools/dtos.py +++ b/api/oss/src/core/tools/dtos.py @@ -5,11 +5,7 @@ from pydantic import BaseModel from oss.src.core.shared.dtos import ( - Header, Identifier, - Lifecycle, - Metadata, - Slug, Json, Status, ) @@ -85,71 +81,6 @@ class ToolCatalogProviderDetails(ToolCatalogProvider): integrations: Optional[List[ToolCatalogIntegration]] = None -# --------------------------------------------------------------------------- -# Tool Connections -# --------------------------------------------------------------------------- - - -class ToolConnectionStatus(BaseModel): - redirect_url: Optional[str] = None - - -class ToolConnectionCreateData(BaseModel): - callback_url: Optional[str] = None - # - auth_scheme: Optional[ToolAuthScheme] = None - - -class ToolConnection( - Identifier, - Slug, - Header, - Lifecycle, - Metadata, -): - provider_key: ToolProviderKind - integration_key: str - # - data: Optional[Json] = None - # - status: Optional[ToolConnectionStatus] = None - - @property - def provider_connection_id(self) -> Optional[str]: - """Get provider-specific connection ID from data.""" - if self.data and isinstance(self.data, dict): - # For Composio, it's stored as "connected_account_id" - return self.data.get("connected_account_id") or self.data.get( - "provider_connection_id" - ) - return None - - @property - def is_active(self) -> bool: - """Check if connection is active (not deleted).""" - if self.flags and isinstance(self.flags, dict): - return self.flags.get("is_active", False) - return False - - @property - def is_valid(self) -> bool: - """Check if connection is valid (authenticated).""" - if self.flags and isinstance(self.flags, dict): - return self.flags.get("is_valid", False) - return False - - -class ToolConnectionCreate( - Slug, - Header, - Metadata, -): - provider_key: ToolProviderKind - integration_key: str - # - data: Optional[ToolConnectionCreateData] = None - - # --------------------------------------------------------------------------- # Tool Calls # --------------------------------------------------------------------------- @@ -191,32 +122,6 @@ class ToolResult(Identifier): data: Optional[ToolResultData] = None -# --------------------------------------------------------------------------- -# Tool Connection (adapter-level DTOs) -# --------------------------------------------------------------------------- - - -class ToolConnectionRequest(BaseModel): - """Input DTO for initiating a provider connection via a gateway adapter.""" - - user_id: str - integration_key: str - auth_scheme: Optional[str] = None - callback_url: Optional[str] = None - - -class ToolConnectionResponse(BaseModel): - """Output DTO from ToolsGatewayInterface.initiate_connection. - - The adapter builds ``connection_data`` with provider-specific fields so the - service never needs to know which provider it is talking to. - """ - - provider_connection_id: str - redirect_url: Optional[str] = None - connection_data: Dict[str, Any] = {} - - # --------------------------------------------------------------------------- # Tool Execution (adapter-level DTOs) # --------------------------------------------------------------------------- diff --git a/api/oss/src/core/tools/interfaces.py b/api/oss/src/core/tools/interfaces.py index fdf0a820f7..0e459619e6 100644 --- a/api/oss/src/core/tools/interfaces.py +++ b/api/oss/src/core/tools/interfaces.py @@ -1,95 +1,25 @@ -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple -from uuid import UUID - -from oss.src.core.tools.dtos import ( - ToolCatalogAction, - ToolCatalogActionDetails, - ToolCatalogIntegration, - ToolCatalogProvider, - ToolConnection, - ToolConnectionCreate, - ToolConnectionRequest, - ToolConnectionResponse, - ToolExecutionRequest, - ToolExecutionResponse, -) - - -class ToolsDAOInterface(ABC): - """Connection persistence contract.""" - - @abstractmethod - async def create_connection( - self, - *, - project_id: UUID, - user_id: UUID, - # - connection_create: ToolConnectionCreate, - ) -> Optional[ToolConnection]: ... - - @abstractmethod - async def get_connection( - self, - *, - project_id: UUID, - connection_id: UUID, - ) -> Optional[ToolConnection]: ... - - @abstractmethod - async def update_connection( - self, - *, - project_id: UUID, - connection_id: UUID, - # - is_valid: Optional[bool] = None, - is_active: Optional[bool] = None, - provider_connection_id: Optional[str] = None, - data_update: Optional[Dict[str, Any]] = None, - ) -> Optional[ToolConnection]: ... - - @abstractmethod - async def delete_connection( - self, - *, - project_id: UUID, - connection_id: UUID, - ) -> bool: ... - - @abstractmethod - async def query_connections( - self, - *, - project_id: UUID, - # - provider_key: Optional[str] = None, - integration_key: Optional[str] = None, - is_active: Optional[bool] = True, - ) -> List[ToolConnection]: ... - - @abstractmethod - async def find_connection_by_provider_id( - self, - *, - provider_connection_id: str, - ) -> Optional[ToolConnection]: ... - - @abstractmethod - async def activate_connection_by_provider_id( - self, - *, - provider_connection_id: str, - project_id: Optional[UUID] = None, - ) -> Optional[ToolConnection]: ... - - -class ToolsGatewayInterface(ABC): - """Port for external tool providers (Composio, Agenta, etc.).""" - - @abstractmethod - async def list_providers(self) -> List[ToolCatalogProvider]: ... +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple + +from oss.src.core.tools.dtos import ( + ToolCatalogAction, + ToolCatalogActionDetails, + ToolCatalogIntegration, + ToolCatalogProvider, + ToolExecutionRequest, + ToolExecutionResponse, +) + + +class ToolsGatewayInterface(ABC): + """Port for external tool providers (Composio, Agenta, etc.). + + Tool-specific verbs only — catalog browse and execution. Connection auth + verbs live behind ``ConnectionsGatewayInterface`` in the connections domain. + """ + + @abstractmethod + async def list_providers(self) -> List[ToolCatalogProvider]: ... @abstractmethod async def list_integrations( @@ -129,51 +59,12 @@ async def get_action( self, *, integration_key: str, - action_key: str, - ) -> Optional[ToolCatalogActionDetails]: ... - - @abstractmethod - async def initiate_connection( - self, - *, - request: ToolConnectionRequest, - ) -> ToolConnectionResponse: - """Initiate a provider-side connection. Returns a typed response with - provider_connection_id, redirect_url, and connection_data — the dict - the service will persist in the local connection record. - """ - ... - - @abstractmethod - async def get_connection_status( - self, - *, - provider_connection_id: str, - ) -> Dict[str, Any]: - """Poll provider for updated connection status.""" - ... - - @abstractmethod - async def refresh_connection( - self, - *, - provider_connection_id: str, - force: bool = False, - callback_url: Optional[str] = None, - integration_key: Optional[str] = None, - user_id: Optional[str] = None, - ) -> Dict[str, Any]: ... - - @abstractmethod - async def revoke_connection( - self, - *, - provider_connection_id: str, - ) -> bool: ... - - @abstractmethod - async def execute( - self, + action_key: str, + ) -> Optional[ToolCatalogActionDetails]: ... + + @abstractmethod + async def execute( + self, *, request: ToolExecutionRequest, ) -> ToolExecutionResponse: diff --git a/api/oss/src/core/tools/providers/composio/adapter.py b/api/oss/src/core/tools/providers/composio/adapter.py index f90ab9aa8e..82dfb56e83 100644 --- a/api/oss/src/core/tools/providers/composio/adapter.py +++ b/api/oss/src/core/tools/providers/composio/adapter.py @@ -9,8 +9,6 @@ from oss.src.core.tools.dtos import ( ToolCatalogActionDetails, ToolCatalogProvider, - ToolConnectionRequest, - ToolConnectionResponse, ToolExecutionRequest, ToolExecutionResponse, ) @@ -28,8 +26,8 @@ class ComposioToolsAdapter(ComposioCatalogClient, ToolsGatewayInterface): """Composio V3 API adapter — uses httpx directly (no SDK). Catalog operations (list/get integrations and actions) are provided by - ``ComposioCatalogClient``. Connection management and tool execution are - implemented here. + ``ComposioCatalogClient``. Tool execution is implemented here. Connection + auth lives in ``ComposioConnectionsAdapter``. """ def __init__( @@ -89,14 +87,6 @@ async def _post( resp.raise_for_status() return resp.json() - async def _delete(self, path: str) -> bool: - resp = await self._client.delete( - f"{self.api_url}{path}", - headers=self._headers(), - ) - resp.raise_for_status() - return True - # ----------------------------------------------------------------------- # Catalog — provider listing # ----------------------------------------------------------------------- @@ -163,217 +153,6 @@ async def get_action( scopes=item.get("scopes") or None, ) - # ----------------------------------------------------------------------- - # Connections - # ----------------------------------------------------------------------- - - async def initiate_connection( - self, - *, - request: ToolConnectionRequest, - ) -> ToolConnectionResponse: - user_id = request.user_id - integration_key = request.integration_key - auth_scheme = request.auth_scheme - callback_url = request.callback_url - - # Step 1: validate the toolkit exists and get its auth scheme info. - try: - toolkit = await self._get(f"/toolkits/{integration_key}") - except httpx.HTTPStatusError as e: - if e.response.status_code == 404: - raise AdapterError( - provider_key="composio", - operation="initiate_connection.validate_toolkit", - detail=f"Integration '{integration_key}' not found", - ) from e - raise AdapterError( - provider_key="composio", - operation="initiate_connection.validate_toolkit", - detail=str(e), - ) from e - except httpx.HTTPError as e: - raise AdapterError( - provider_key="composio", - operation="initiate_connection.validate_toolkit", - detail=str(e), - ) from e - - # Step 2: create an auth config for this integration. - # api_key → use_custom_auth; Composio's redirect UI collects the credentials. - # oauth / None → use_composio_managed_auth. - log.info( - "initiate_connection: integration_key=%s auth_scheme=%r", - integration_key, - auth_scheme, - ) - - if auth_scheme == "api_key": - # Derive Composio authScheme from toolkit's auth_config_details. - # Fall back to "API_KEY" as the common default. - composio_auth_scheme = "API_KEY" - for detail in toolkit.get("auth_config_details") or []: - mode = detail.get("mode", "") - if mode and "oauth" not in mode.lower(): - composio_auth_scheme = mode - break - - auth_config_body: Dict[str, Any] = { - "type": "use_custom_auth", - "authScheme": composio_auth_scheme, - } - else: - auth_config_body = {"type": "use_composio_managed_auth"} - - auth_configs_payload = { - "toolkit": {"slug": integration_key}, - "auth_config": auth_config_body, - } - log.info( - "initiate_connection: POST /auth_configs payload=%s", auth_configs_payload - ) - - try: - auth_config_result = await self._post( - "/auth_configs", - json=auth_configs_payload, - ) - except httpx.HTTPError as e: - raise AdapterError( - provider_key="composio", - operation="initiate_connection.create_auth_config", - detail=str(e), - ) from e - - auth_config_id = (auth_config_result.get("auth_config") or {}).get("id") - if not auth_config_id: - raise AdapterError( - provider_key="composio", - operation="initiate_connection.create_auth_config", - detail=f"No auth_config_id in response for integration '{integration_key}'", - ) - - log.info( - "initiate_connection: integration_key=%s auth_config_id=%s", - integration_key, - auth_config_id, - ) - - # Step 3: initiate connected account link. - payload: Dict[str, Any] = { - "user_id": user_id, - "auth_config_id": auth_config_id, - } - if callback_url: - payload["callback_url"] = callback_url - - try: - result = await self._post("/connected_accounts/link", json=payload) - except httpx.HTTPError as e: - raise AdapterError( - provider_key="composio", - operation="initiate_connection", - detail=str(e), - ) from e - - provider_connection_id = result.get("connected_account_id", "") - redirect_url = result.get("redirect_url") - - connection_data: Dict[str, Any] = { - "connected_account_id": provider_connection_id, - "auth_config_id": auth_config_id, - } - if redirect_url: - connection_data["redirect_url"] = redirect_url - - return ToolConnectionResponse( - provider_connection_id=provider_connection_id, - redirect_url=redirect_url, - connection_data=connection_data, - ) - - async def get_connection_status( - self, - *, - provider_connection_id: str, - ) -> Dict[str, Any]: - try: - result = await self._get(f"/connected_accounts/{provider_connection_id}") - except httpx.HTTPError as e: - raise AdapterError( - provider_key="composio", - operation="get_connection_status", - detail=str(e), - ) from e - - return { - "status": result.get("status"), - "is_valid": result.get("status") == "ACTIVE", - } - - async def refresh_connection( - self, - *, - provider_connection_id: str, - force: bool = False, - callback_url: Optional[str] = None, - integration_key: Optional[str] = None, - user_id: Optional[str] = None, - ) -> Dict[str, Any]: - # For Composio OAuth flows, "refresh" means re-initiating the auth link. - # The provider does not expose a token-refresh endpoint for OAuth connections, - # so we create a new connected_accounts/link which the user must re-authorize. - if integration_key and user_id: - result = await self.initiate_connection( - request=ToolConnectionRequest( - user_id=user_id, - integration_key=integration_key, - callback_url=callback_url, - ), - ) - return { - "id": result.provider_connection_id, - "redirect_url": result.redirect_url, - "auth_config_id": result.connection_data.get("auth_config_id"), - "is_valid": False, # Re-auth pending until callback fires - } - - payload: Dict[str, Any] = {} - if callback_url: - payload["callback_url"] = callback_url - - try: - result = await self._post( - f"/connected_accounts/{provider_connection_id}/refresh", - json=payload, - ) - except httpx.HTTPError as e: - raise AdapterError( - provider_key="composio", - operation="refresh_connection", - detail=str(e), - ) from e - - return { - "status": result.get("status"), - "is_valid": result.get("status") == "ACTIVE", - "redirect_url": result.get("redirect_url"), - } - - async def revoke_connection( - self, - *, - provider_connection_id: str, - ) -> bool: - try: - return await self._delete(f"/connected_accounts/{provider_connection_id}") - except httpx.HTTPError as e: - raise AdapterError( - provider_key="composio", - operation="revoke_connection", - detail=str(e), - ) from e - # ----------------------------------------------------------------------- # Execution # ----------------------------------------------------------------------- diff --git a/api/oss/src/core/tools/service.py b/api/oss/src/core/tools/service.py index f603bc4d42..4b30ca6121 100644 --- a/api/oss/src/core/tools/service.py +++ b/api/oss/src/core/tools/service.py @@ -1,45 +1,36 @@ from typing import Any, Dict, List, Optional, Tuple -from uuid import UUID - -from oss.src.utils.logging import get_module_logger -from oss.src.utils.env import env -from oss.src.core.tools.utils import make_oauth_state - -from oss.src.core.tools.dtos import ( - ToolCatalogAction, - ToolCatalogActionDetails, - ToolCatalogIntegration, - ToolCatalogProvider, - ToolConnection, - ToolConnectionCreate, - ToolConnectionRequest, - ToolExecutionRequest, - ToolExecutionResponse, -) -from oss.src.core.tools.interfaces import ( - ToolsDAOInterface, -) -from oss.src.core.tools.registry import ToolsGatewayRegistry -from oss.src.core.tools.exceptions import ( - ConnectionInactiveError, - ConnectionNotFoundError, -) - - -log = get_module_logger(__name__) +from uuid import UUID + +from oss.src.utils.logging import get_module_logger + +from oss.src.core.connections.dtos import Connection, ConnectionCreate +from oss.src.core.connections.service import ConnectionsService + +from oss.src.core.tools.dtos import ( + ToolCatalogAction, + ToolCatalogActionDetails, + ToolCatalogIntegration, + ToolCatalogProvider, + ToolExecutionRequest, + ToolExecutionResponse, +) +from oss.src.core.tools.registry import ToolsGatewayRegistry + + +log = get_module_logger(__name__) class ToolsService: - def __init__( - self, - *, - tools_dao: ToolsDAOInterface, - adapter_registry: ToolsGatewayRegistry, - ): - self.tools_dao = tools_dao - self.adapter_registry = adapter_registry - - # ----------------------------------------------------------------------- + def __init__( + self, + *, + connections_service: ConnectionsService, + adapter_registry: ToolsGatewayRegistry, + ): + self.connections_service = connections_service + self.adapter_registry = adapter_registry + + # ----------------------------------------------------------------------- # Catalog browse # ----------------------------------------------------------------------- @@ -129,261 +120,125 @@ async def get_action( return await adapter.get_action( integration_key=integration_key, action_key=action_key, - ) - - # ----------------------------------------------------------------------- - # Connection management - # ----------------------------------------------------------------------- - - async def query_connections( + ) + + # ----------------------------------------------------------------------- + # Connection management (delegated to ConnectionsService — one-way dep) + # ----------------------------------------------------------------------- + + async def query_connections( self, *, project_id: UUID, # - provider_key: Optional[str] = None, - integration_key: Optional[str] = None, - is_active: Optional[bool] = True, - ) -> List[ToolConnection]: - """Query connections with optional filtering. Defaults to active-only.""" - return await self.tools_dao.query_connections( - project_id=project_id, - provider_key=provider_key, - integration_key=integration_key, - is_active=is_active, - ) - - async def find_connection_by_provider_connection_id( - self, - *, - provider_connection_id: str, - ) -> Optional[ToolConnection]: - """Find any connection by its provider-side ID (for OAuth callbacks).""" - return await self.tools_dao.find_connection_by_provider_id( - provider_connection_id=provider_connection_id, - ) - - async def activate_connection_by_provider_connection_id( - self, - *, - provider_connection_id: str, - project_id: Optional[UUID] = None, - ) -> Optional[ToolConnection]: - """Mark a connection valid+active after OAuth completes.""" - return await self.tools_dao.activate_connection_by_provider_id( - provider_connection_id=provider_connection_id, - project_id=project_id, - ) - - async def list_connections( - self, - *, - project_id: UUID, - provider_key: str, - integration_key: str, - ) -> List[ToolConnection]: - """List connections for a specific integration (catalog enrichment).""" - return await self.tools_dao.query_connections( - project_id=project_id, - provider_key=provider_key, - integration_key=integration_key, + provider_key: Optional[str] = None, + integration_key: Optional[str] = None, + is_active: Optional[bool] = True, + ) -> List[Connection]: + return await self.connections_service.query_connections( + project_id=project_id, + provider_key=provider_key, + integration_key=integration_key, + is_active=is_active, + ) + + async def list_connections( + self, + *, + project_id: UUID, + provider_key: str, + integration_key: str, + ) -> List[Connection]: + return await self.connections_service.list_connections( + project_id=project_id, + provider_key=provider_key, + integration_key=integration_key, ) async def get_connection( self, - *, - project_id: UUID, - connection_id: UUID, - ) -> Optional[ToolConnection]: - """Return a single connection by ID scoped to the project, or None.""" - # Read-only by design: do not mutate local state during GET. - return await self.tools_dao.get_connection( - project_id=project_id, - connection_id=connection_id, - ) - - async def create_connection( - self, - *, - project_id: UUID, - user_id: UUID, - # - connection_create: ToolConnectionCreate, - ) -> ToolConnection: - """Initiate a provider connection and persist it locally in pending state.""" - provider_key = connection_create.provider_key.value - integration_key = connection_create.integration_key - - adapter = self.adapter_registry.get(provider_key) - - # Callback URL is server-owned. Do not trust/require client-provided values. - # Embed a signed state token so the callback can scope the activation. - state = make_oauth_state( - project_id=project_id, - user_id=user_id, - secret_key=env.agenta.crypt_key, - ) - callback_url = f"{env.agenta.api_url}/tools/connections/callback?state={state}" - - # Initiate with provider - connection_create_data = connection_create.data - provider_result = await adapter.initiate_connection( - request=ToolConnectionRequest( - user_id=str(project_id), - integration_key=integration_key, - auth_scheme=connection_create_data.auth_scheme.value - if connection_create_data and connection_create_data.auth_scheme - else None, - callback_url=callback_url, - ), - ) - - # Merge provider-returned connection_data with service-level project_id. - # The adapter owns provider-specific field names; the service adds project scope. - data: Dict[str, Any] = dict(provider_result.connection_data) - data["project_id"] = str(project_id) - connection_create.data = data # type: ignore[assignment] - - # Persist locally - return await self.tools_dao.create_connection( - project_id=project_id, - user_id=user_id, - # + *, + project_id: UUID, + connection_id: UUID, + ) -> Optional[Connection]: + return await self.connections_service.get_connection( + project_id=project_id, + connection_id=connection_id, + ) + + async def find_connection_by_provider_connection_id( + self, + *, + provider_connection_id: str, + ) -> Optional[Connection]: + return await self.connections_service.find_connection_by_provider_connection_id( + provider_connection_id=provider_connection_id, + ) + + async def activate_connection_by_provider_connection_id( + self, + *, + provider_connection_id: str, + project_id: Optional[UUID] = None, + ) -> Optional[Connection]: + return await self.connections_service.activate_connection_by_provider_connection_id( + provider_connection_id=provider_connection_id, + project_id=project_id, + ) + + async def create_connection( + self, + *, + project_id: UUID, + user_id: UUID, + # + connection_create: ConnectionCreate, + ) -> Connection: + return await self.connections_service.initiate_connection( + project_id=project_id, + user_id=user_id, + # connection_create=connection_create, ) async def delete_connection( self, *, - project_id: UUID, - connection_id: UUID, - ) -> bool: - """Revoke provider-side connection and delete locally. Raises ConnectionNotFoundError if missing.""" - # Look up connection - conn = await self.tools_dao.get_connection( - project_id=project_id, - connection_id=connection_id, - ) - - if not conn: - raise ConnectionNotFoundError( - connection_id=str(connection_id), - ) - - # Revoke provider-side - if conn.provider_connection_id: - adapter = self.adapter_registry.get(conn.provider_key.value) - try: - await adapter.revoke_connection( - provider_connection_id=conn.provider_connection_id, - ) - except Exception: - log.warning( - "Failed to revoke provider connection %s, proceeding with local delete", - conn.provider_connection_id, - ) - - # Delete locally - return await self.tools_dao.delete_connection( - project_id=project_id, - connection_id=connection_id, - ) + project_id: UUID, + connection_id: UUID, + ) -> bool: + return await self.connections_service.delete_connection( + project_id=project_id, + connection_id=connection_id, + ) async def revoke_connection( self, - *, + *, + project_id: UUID, + connection_id: UUID, + ) -> Connection: + return await self.connections_service.revoke_connection( + project_id=project_id, + connection_id=connection_id, + ) + + async def refresh_connection( + self, + *, project_id: UUID, - connection_id: UUID, - ) -> ToolConnection: - """Mark a connection invalid locally without touching the provider.""" - conn = await self.tools_dao.get_connection( - project_id=project_id, - connection_id=connection_id, - ) - - if not conn: - raise ConnectionNotFoundError( - connection_id=str(connection_id), - ) - - updated = await self.tools_dao.update_connection( - project_id=project_id, - connection_id=connection_id, - is_valid=False, - ) - - return updated or conn - - async def refresh_connection( - self, - *, - project_id: UUID, - connection_id: UUID, - # - force: bool = False, - ) -> ToolConnection: - conn = await self.tools_dao.get_connection( - project_id=project_id, - connection_id=connection_id, - ) - - if not conn: - raise ConnectionNotFoundError( - connection_id=str(connection_id), - ) - - if not conn.provider_connection_id: - raise ConnectionNotFoundError( - connection_id=str(connection_id), - ) - - if not conn.is_active: - raise ConnectionInactiveError( - connection_id=str(connection_id), - detail="Cannot refresh an inactive connection. Create a new connection to re-establish authorization.", - ) - - # Callback URL is server-owned with a signed state token. - state = make_oauth_state( - project_id=project_id, - user_id=project_id, # refresh has no user_id; use project_id as entity - secret_key=env.agenta.crypt_key, - ) - callback_url = f"{env.agenta.api_url}/tools/connections/callback?state={state}" - - adapter = self.adapter_registry.get(conn.provider_key.value) - - # Delegate provider-specific refresh logic to the adapter. - # For OAuth providers (e.g. Composio), the adapter re-initiates the link. - provider_connection_id = conn.provider_connection_id - result = await adapter.refresh_connection( - provider_connection_id=conn.provider_connection_id, - force=force, - callback_url=callback_url, - integration_key=conn.integration_key, - user_id=str(project_id), - ) - provider_connection_id = result.get("id") or provider_connection_id - auth_config_id = result.get("auth_config_id") - is_valid = result.get("is_valid", conn.is_valid) - - redirect_url = result.get("redirect_url") - # Always overwrite redirect_url so FE doesn't reuse stale links from prior flows. - data_update = {"redirect_url": redirect_url} - if auth_config_id: - data_update["auth_config_id"] = auth_config_id - - updated = await self.tools_dao.update_connection( - project_id=project_id, - connection_id=connection_id, - is_valid=is_valid, - provider_connection_id=provider_connection_id, - data_update=data_update, - ) - - return updated or conn - - # ----------------------------------------------------------------------- - # Tool execution + connection_id: UUID, + # + force: bool = False, + ) -> Connection: + return await self.connections_service.refresh_connection( + project_id=project_id, + connection_id=connection_id, + force=force, + ) + + # ----------------------------------------------------------------------- + # Tool execution # ----------------------------------------------------------------------- async def execute_tool( diff --git a/api/oss/src/dbs/postgres/connections/__init__.py b/api/oss/src/dbs/postgres/connections/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/oss/src/dbs/postgres/connections/dao.py b/api/oss/src/dbs/postgres/connections/dao.py new file mode 100644 index 0000000000..7be2165ef7 --- /dev/null +++ b/api/oss/src/dbs/postgres/connections/dao.py @@ -0,0 +1,282 @@ +from typing import List, Optional +from datetime import datetime, timezone +from uuid import UUID + +from sqlalchemy import select, delete +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm.attributes import flag_modified + +from oss.src.utils.logging import get_module_logger +from oss.src.utils.exceptions import suppress_exceptions + +from oss.src.core.shared.exceptions import EntityCreationConflict +from oss.src.core.connections.interfaces import ConnectionsDAOInterface +from oss.src.core.connections.dtos import ( + Connection, + ConnectionCreate, +) + +from oss.src.dbs.postgres.shared.engine import ( + TransactionsEngine, + get_transactions_engine, +) +from oss.src.dbs.postgres.connections.dbes import ConnectionDBE +from oss.src.dbs.postgres.connections.mappings import ( + map_connection_create_to_dbe, + map_connection_dbe_to_dto, +) + + +log = get_module_logger(__name__) + + +class ConnectionsDAO(ConnectionsDAOInterface): + def __init__( + self, + *, + ConnectionDBE: type = ConnectionDBE, + engine: TransactionsEngine = None, + ): + self.ConnectionDBE = ConnectionDBE + if engine is None: + engine = get_transactions_engine() + self.engine = engine + + @suppress_exceptions(exclude=[EntityCreationConflict]) + async def create_connection( + self, + *, + project_id: UUID, + user_id: UUID, + # + connection_create: ConnectionCreate, + ) -> Optional[Connection]: + """Insert a new connection row. Raises EntityCreationConflict on slug collision.""" + dbe = map_connection_create_to_dbe( + project_id=project_id, + user_id=user_id, + # + dto=connection_create, + ) + + try: + async with self.engine.session() as session: + session.add(dbe) + await session.commit() + await session.refresh(dbe) + + return map_connection_dbe_to_dto(dbe=dbe) + + except IntegrityError as e: + error_str = str(e.orig) if e.orig else str(e) + if "uq_gateway_connections_project_provider_integration_slug" in error_str: + raise EntityCreationConflict( + entity="Connection", + message="Connection with slug '{{slug}}' already exists for this integration.".replace( + "{{slug}}", connection_create.slug + ), + conflict={ + "provider_key": connection_create.provider_key, + "integration_key": connection_create.integration_key, + "slug": connection_create.slug, + }, + ) from e + raise + + @suppress_exceptions(default=None) + async def get_connection( + self, + *, + project_id: UUID, + connection_id: UUID, + ) -> Optional[Connection]: + """Fetch a connection by ID scoped to project_id. Returns None if not found.""" + async with self.engine.session() as session: + stmt = ( + select(self.ConnectionDBE) + .filter(self.ConnectionDBE.project_id == project_id) + .filter(self.ConnectionDBE.id == connection_id) + .limit(1) + ) + + result = await session.execute(stmt) + dbe = result.scalars().first() + + if not dbe: + return None + + return map_connection_dbe_to_dto(dbe=dbe) + + @suppress_exceptions(default=None) + async def update_connection( + self, + *, + project_id: UUID, + connection_id: UUID, + # + is_valid: Optional[bool] = None, + is_active: Optional[bool] = None, + provider_connection_id: Optional[str] = None, + data_update: Optional[dict] = None, + ) -> Optional[Connection]: + """Partially update flags and/or data for a connection. Returns updated DTO or None.""" + async with self.engine.session() as session: + stmt = ( + select(self.ConnectionDBE) + .filter(self.ConnectionDBE.project_id == project_id) + .filter(self.ConnectionDBE.id == connection_id) + .limit(1) + ) + + result = await session.execute(stmt) + dbe = result.scalars().first() + + if not dbe: + return None + + # Update flags + if is_valid is not None or is_active is not None: + flags = {**(dbe.flags or {})} + if is_valid is not None: + flags["is_valid"] = is_valid + if is_active is not None: + flags["is_active"] = is_active + dbe.flags = flags + flag_modified(dbe, "flags") + + # Update data fields + data_patch: dict = {} + if provider_connection_id is not None: + data_patch["connected_account_id"] = provider_connection_id + if data_update: + data_patch.update(data_update) + if data_patch: + dbe.data = {**(dbe.data or {}), **data_patch} + flag_modified(dbe, "data") + + dbe.updated_at = datetime.now(timezone.utc) + + await session.commit() + await session.refresh(dbe) + + return map_connection_dbe_to_dto(dbe=dbe) + + @suppress_exceptions(default=False) + async def delete_connection( + self, + *, + project_id: UUID, + connection_id: UUID, + ) -> bool: + """Hard-delete a connection row. Returns True if a row was deleted.""" + async with self.engine.session() as session: + stmt = ( + delete(self.ConnectionDBE) + .where(self.ConnectionDBE.project_id == project_id) + .where(self.ConnectionDBE.id == connection_id) + ) + + result = await session.execute(stmt) + await session.commit() + + return result.rowcount > 0 + + @suppress_exceptions(default=[]) + async def query_connections( + self, + *, + project_id: UUID, + # + provider_key: Optional[str] = None, + integration_key: Optional[str] = None, + is_active: Optional[bool] = True, + ) -> List[Connection]: + """List connections with optional filters. Defaults to active-only (is_active=True).""" + async with self.engine.session() as session: + stmt = select(self.ConnectionDBE).filter( + self.ConnectionDBE.project_id == project_id, + ) + + if provider_key: + stmt = stmt.filter(self.ConnectionDBE.provider_key == provider_key) + + if integration_key: + stmt = stmt.filter( + self.ConnectionDBE.integration_key == integration_key + ) + + if is_active is not None: + expected = "true" if is_active else "false" + stmt = stmt.filter( + self.ConnectionDBE.flags["is_active"].astext == expected + ) + + stmt = stmt.order_by(self.ConnectionDBE.created_at.desc()) + + result = await session.execute(stmt) + dbes = result.scalars().all() + + return [map_connection_dbe_to_dto(dbe=dbe) for dbe in dbes] + + @suppress_exceptions(default=None) + async def activate_connection_by_provider_id( + self, + *, + provider_connection_id: str, + project_id: Optional[UUID] = None, + ) -> Optional[Connection]: + """Set is_valid=True and is_active=True for the connection matching the provider ID.""" + async with self.engine.session() as session: + stmt = select(self.ConnectionDBE).filter( + self.ConnectionDBE.data["connected_account_id"].astext + == provider_connection_id + ) + + if project_id is not None: + stmt = stmt.filter(self.ConnectionDBE.project_id == project_id) + + stmt = stmt.limit(1) + + result = await session.execute(stmt) + dbe = result.scalars().first() + + if not dbe: + return None + + flags = {**(dbe.flags or {})} + flags["is_valid"] = True + flags["is_active"] = True + dbe.flags = flags + flag_modified(dbe, "flags") + + dbe.updated_at = datetime.now(timezone.utc) + + await session.commit() + await session.refresh(dbe) + + return map_connection_dbe_to_dto(dbe=dbe) + + @suppress_exceptions(default=None) + async def find_connection_by_provider_id( + self, + *, + provider_connection_id: str, + ) -> Optional[Connection]: + """Lookup any connection by provider-side connected_account_id (no project scope).""" + async with self.engine.session() as session: + stmt = ( + select(self.ConnectionDBE) + .filter( + self.ConnectionDBE.data["connected_account_id"].astext + == provider_connection_id + ) + .limit(1) + ) + + result = await session.execute(stmt) + dbe = result.scalars().first() + + if not dbe: + return None + + return map_connection_dbe_to_dto(dbe=dbe) diff --git a/api/oss/src/dbs/postgres/connections/dbes.py b/api/oss/src/dbs/postgres/connections/dbes.py new file mode 100644 index 0000000000..087f03e9b1 --- /dev/null +++ b/api/oss/src/dbs/postgres/connections/dbes.py @@ -0,0 +1,69 @@ +from sqlalchemy import ( + Column, + ForeignKeyConstraint, + Index, + PrimaryKeyConstraint, + String, + UniqueConstraint, +) + +from oss.src.dbs.postgres.shared.base import Base +from oss.src.dbs.postgres.shared.dbas import ( + DataDBA, + FlagsDBA, + HeaderDBA, + IdentifierDBA, + LifecycleDBA, + MetaDBA, + ProjectScopeDBA, + SlugDBA, + StatusDBA, + TagsDBA, +) + + +class ConnectionDBE( + Base, + ProjectScopeDBA, + IdentifierDBA, + SlugDBA, + LifecycleDBA, + HeaderDBA, + TagsDBA, + FlagsDBA, + DataDBA, + StatusDBA, + MetaDBA, +): + __tablename__ = "gateway_connections" + + __table_args__ = ( + PrimaryKeyConstraint("project_id", "id"), + UniqueConstraint( + "project_id", + "provider_key", + "integration_key", + "slug", + name="uq_gateway_connections_project_provider_integration_slug", + ), + ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + ondelete="CASCADE", + ), + Index( + "ix_gateway_connections_project_provider_integration", + "project_id", + "provider_key", + "integration_key", + ), + ) + + provider_key = Column( + String, + nullable=False, + ) + integration_key = Column( + String, + nullable=False, + ) diff --git a/api/oss/src/dbs/postgres/tools/mappings.py b/api/oss/src/dbs/postgres/connections/mappings.py similarity index 80% rename from api/oss/src/dbs/postgres/tools/mappings.py rename to api/oss/src/dbs/postgres/connections/mappings.py index 334fd600c0..a7036d44b1 100644 --- a/api/oss/src/dbs/postgres/tools/mappings.py +++ b/api/oss/src/dbs/postgres/connections/mappings.py @@ -2,12 +2,12 @@ from pydantic import BaseModel -from oss.src.core.tools.dtos import ( - ToolConnection, - ToolConnectionCreate, - ToolConnectionStatus, +from oss.src.core.connections.dtos import ( + Connection, + ConnectionCreate, + ConnectionStatus, ) -from oss.src.dbs.postgres.tools.dbes import ToolConnectionDBE +from oss.src.dbs.postgres.connections.dbes import ConnectionDBE def map_connection_create_to_dbe( @@ -15,8 +15,8 @@ def map_connection_create_to_dbe( project_id: UUID, user_id: UUID, # - dto: ToolConnectionCreate, -) -> ToolConnectionDBE: + dto: ConnectionCreate, +) -> ConnectionDBE: # Serialize provider-specific data to dict if present data = None if dto.data: @@ -30,7 +30,7 @@ def map_connection_create_to_dbe( flags.setdefault("is_active", True) flags.setdefault("is_valid", False) - return ToolConnectionDBE( + return ConnectionDBE( project_id=project_id, slug=dto.slug, name=dto.name, @@ -50,17 +50,17 @@ def map_connection_create_to_dbe( def map_connection_dbe_to_dto( *, - dbe: ToolConnectionDBE, -) -> ToolConnection: + dbe: ConnectionDBE, +) -> Connection: # Keep provider data generic in core DTOs. data = dbe.data or None # Parse status status = None if dbe.status: - status = ToolConnectionStatus(**dbe.status) + status = ConnectionStatus(**dbe.status) - return ToolConnection( + return Connection( id=dbe.id, slug=dbe.slug, name=dbe.name, diff --git a/api/oss/tests/pytest/acceptance/tools/test_tools_connections.py b/api/oss/tests/pytest/acceptance/tools/test_tools_connections.py new file mode 100644 index 0000000000..bd963b0d33 --- /dev/null +++ b/api/oss/tests/pytest/acceptance/tools/test_tools_connections.py @@ -0,0 +1,71 @@ +"""Acceptance tests for the /tools/connections contract (WP0). + +The connection now lives in the routerless ``connections`` domain backed by the +``gateway_connections`` table, but the public HTTP surface stays at +``/tools/connections`` byte-for-byte. These tests pin that contract. + +The query endpoint is DB-only — it needs no Composio credentials. A fresh +project returns an empty, well-shaped list, which also proves the table rename +landed (the query hits ``gateway_connections``). Create / refresh / revoke make +real provider calls, so those are gated on COMPOSIO_API_KEY. +""" + +import os +from uuid import uuid4 + +import pytest + + +_COMPOSIO_ENABLED = bool(os.getenv("COMPOSIO_API_KEY")) +_requires_composio = pytest.mark.skipif( + not _COMPOSIO_ENABLED, + reason="needs live Composio credentials (COMPOSIO_API_KEY)", +) + + +class TestToolsConnectionsQuery: + def test_query_connections_returns_200(self, authed_api): + response = authed_api("POST", "/tools/connections/query") + assert response.status_code == 200 + + def test_query_connections_response_shape(self, authed_api): + body = authed_api("POST", "/tools/connections/query").json() + assert "count" in body + assert "connections" in body + assert isinstance(body["connections"], list) + assert body["count"] == len(body["connections"]) + + +class TestToolsConnectionsGet: + def test_get_unknown_connection_returns_404(self, authed_api): + response = authed_api("GET", f"/tools/connections/{uuid4()}") + assert response.status_code == 404 + + +@_requires_composio +class TestToolsConnectionsLifecycle: + def test_create_revoke_roundtrip(self, authed_api): + slug = f"acc-{uuid4().hex[:8]}" + create = authed_api( + "POST", + "/tools/connections/", + json={ + "connection": { + "slug": slug, + "provider_key": "composio", + "integration_key": "github", + "data": {"auth_scheme": "oauth"}, + } + }, + ) + assert create.status_code == 200, create.text + connection = create.json()["connection"] + connection_id = connection["id"] + + # Local-only revoke (C7/B3): flips is_valid on the shared row, no + # provider call, no cascade. + revoke = authed_api("POST", f"/tools/connections/{connection_id}/revoke") + assert revoke.status_code == 200, revoke.text + assert revoke.json()["connection"]["flags"]["is_valid"] is False + + authed_api("DELETE", f"/tools/connections/{connection_id}") diff --git a/api/oss/tests/pytest/unit/models/test_lifecycle_conventions.py b/api/oss/tests/pytest/unit/models/test_lifecycle_conventions.py index 8e0399f3ec..ef15727bc8 100644 --- a/api/oss/tests/pytest/unit/models/test_lifecycle_conventions.py +++ b/api/oss/tests/pytest/unit/models/test_lifecycle_conventions.py @@ -16,7 +16,7 @@ "oss.src.dbs.postgres.users.dbes", "oss.src.dbs.postgres.folders.dbes", "oss.src.dbs.postgres.secrets.dbes", - "oss.src.dbs.postgres.tools.dbes", + "oss.src.dbs.postgres.connections.dbes", "oss.src.dbs.postgres.events.dbes", "oss.src.dbs.postgres.webhooks.dbes", "oss.src.dbs.postgres.tracing.dbes",