diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index f3751206a8..2710c3894c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index d857da9635..e31db15788 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/pyproject.toml b/pyproject.toml index 3a2a587b5d..15e6054c9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,6 +123,7 @@ test = [ "a2a-sdk>=0.3.0,<0.4.0", "anthropic>=0.43.0", # For anthropic model tests "crewai[tools];python_version>='3.11' and python_version<'3.12'", # For CrewaiTool tests; chromadb/pypika fail on 3.12+ + "google-cloud-iamconnectorcredentials>=0.1.0, <0.2.0", "google-cloud-parametermanager>=0.4.0, <1.0.0", "kubernetes>=29.0.0", # For GkeCodeExecutor "langchain-community>=0.3.17", @@ -176,6 +177,10 @@ toolbox = ["toolbox-adk>=1.0.0, <2.0.0"] slack = ["slack-bolt>=1.22.0"] +agent-identity = [ + "google-cloud-iamconnectorcredentials>=0.1.0, <0.2.0", +] + [tool.pyink] # Format py files following Google style-guide line-length = 80 diff --git a/src/google/adk/artifacts/file_artifact_service.py b/src/google/adk/artifacts/file_artifact_service.py index b0078e27ce..9c3870b6e3 100644 --- a/src/google/adk/artifacts/file_artifact_service.py +++ b/src/google/adk/artifacts/file_artifact_service.py @@ -138,6 +138,31 @@ def _is_user_scoped(session_id: Optional[str], filename: str) -> bool: return session_id is None or _file_has_user_namespace(filename) +def _validate_path_segment(value: str, field_name: str) -> None: + """Rejects values that could alter the constructed filesystem path. + + Args: + value: The caller-supplied identifier (e.g. user_id or session_id). + field_name: Human-readable name used in the error message. + + Raises: + InputValidationError: If the value contains path separators, traversal + segments, or null bytes. + """ + if not value: + raise InputValidationError(f"{field_name} must not be empty.") + if "\x00" in value: + raise InputValidationError(f"{field_name} must not contain null bytes.") + if "/" in value or "\\" in value: + raise InputValidationError( + f"{field_name} {value!r} must not contain path separators." + ) + if value in (".", "..") or ".." in value.split("/"): + raise InputValidationError( + f"{field_name} {value!r} must not contain traversal segments." + ) + + def _user_artifacts_dir(base_root: Path) -> Path: """Returns the path that stores user-scoped artifacts.""" return base_root / "artifacts" @@ -145,6 +170,7 @@ def _user_artifacts_dir(base_root: Path) -> Path: def _session_artifacts_dir(base_root: Path, session_id: str) -> Path: """Returns the path that stores session-scoped artifacts.""" + _validate_path_segment(session_id, "session_id") return base_root / "sessions" / session_id / "artifacts" @@ -220,6 +246,7 @@ def __init__(self, root_dir: Path | str): def _base_root(self, user_id: str, /) -> Path: """Returns the artifacts root directory for a user.""" + _validate_path_segment(user_id, "user_id") return self.root_dir / "users" / user_id def _scope_root( diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py index e0c7f9b28c..ca06b1704f 100644 --- a/src/google/adk/cli/cli_deploy.py +++ b/src/google/adk/cli/cli_deploy.py @@ -808,6 +808,24 @@ def to_cloud_run( shutil.rmtree(temp_folder) +def _print_agent_engine_url(resource_name: str) -> None: + """Prints the Google Cloud Console URL for the deployed agent.""" + parts = resource_name.split('/') + if len(parts) >= 6 and parts[0] == 'projects' and parts[2] == 'locations': + project_id = parts[1] + region = parts[3] + engine_id = parts[5] + + url = ( + 'https://console.cloud.google.com/agent-platform/runtimes' + f'/locations/{region}/agent-engines/{engine_id}/playground' + f'?project={project_id}' + ) + click.secho( + f'\nšŸŽ‰ View your deployed agent here:\n{url}\n', fg='cyan', bold=True + ) + + def to_agent_engine( *, agent_folder: str, @@ -1150,11 +1168,13 @@ def to_agent_engine( f'āœ… Created agent engine: {agent_engine.api_resource.name}', fg='green', ) + _print_agent_engine_url(agent_engine.api_resource.name) else: if project and region and not agent_engine_id.startswith('projects/'): agent_engine_id = f'projects/{project}/locations/{region}/reasoningEngines/{agent_engine_id}' client.agent_engines.update(name=agent_engine_id, config=agent_config) click.secho(f'āœ… Updated agent engine: {agent_engine_id}', fg='green') + _print_agent_engine_url(agent_engine_id) finally: click.echo(f'Cleaning up the temp folder: {temp_folder}') shutil.rmtree(agent_src_path) diff --git a/src/google/adk/evaluation/evaluation_generator.py b/src/google/adk/evaluation/evaluation_generator.py index f8fb6795aa..6fac2decdd 100644 --- a/src/google/adk/evaluation/evaluation_generator.py +++ b/src/google/adk/evaluation/evaluation_generator.py @@ -281,7 +281,7 @@ def convert_events_to_eval_invocations( for invocation_id, events in events_by_invocation_id.items(): final_response = None final_event = None - user_content = Content(parts=[]) + user_content = None invocation_timestamp = 0 app_details = None if ( @@ -312,6 +312,18 @@ def convert_events_to_eval_invocations( events_to_add.append(event) break + if user_content is None: + # Skip invocations that have no user-authored event. Such invocations + # arise from internal/system-driven turns (e.g. background agent tasks) + # and are not meaningful for evaluation purposes. Including them would + # also cause a Pydantic ValidationError because Invocation.user_content + # requires a Content object. + logger.debug( + "Skipping invocation %s: no user-authored event found.", + invocation_id, + ) + continue + invocation_events = [ InvocationEvent(author=e.author, content=e.content) for e in events_to_add diff --git a/src/google/adk/integrations/agent_identity/README.md b/src/google/adk/integrations/agent_identity/README.md new file mode 100644 index 0000000000..26c8fe7fe9 --- /dev/null +++ b/src/google/adk/integrations/agent_identity/README.md @@ -0,0 +1,35 @@ +# GCP IAM Connector Auth + +Manages the complete lifecycle of an access token using the Google Cloud +Platform Agent Identity Credentials service. + +## Usage + +1. **Install Dependencies:** + ```bash + pip install "google-adk[agent-identity]" + ``` + +2. **Register the provider:** + Register the `GcpAuthProvider` with the `CredentialManager`. This is to be + done one time. + + ``` py + # user_agent_app.py + from google.adk.auth.credential_manager import CredentialManager + from google.adk.integrations.agent_identity import GcpAuthProvider + + CredentialManager.register_auth_provider(GcpAuthProvider()) + ``` + +3. **Configure the Auth provider:** + Specify the Agent Identity provider configuration using the + `GcpAuthProviderScheme`. + ``` py + # user_agent_app.py + from google.adk.integrations.agent_identity import GcpAuthProviderScheme + + # Configures Toolset + auth_scheme = GcpAuthProviderScheme(name="my-jira-auth_provider") + mcp_toolset_jira = McpToolset(..., auth_scheme=auth_scheme) + ``` diff --git a/src/google/adk/integrations/agent_identity/__init__.py b/src/google/adk/integrations/agent_identity/__init__.py new file mode 100644 index 0000000000..1025735236 --- /dev/null +++ b/src/google/adk/integrations/agent_identity/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .gcp_auth_provider import GcpAuthProvider +from .gcp_auth_provider_scheme import GcpAuthProviderScheme + +__all__ = [ + "GcpAuthProvider", + "GcpAuthProviderScheme", +] diff --git a/src/google/adk/integrations/agent_identity/gcp_auth_provider.py b/src/google/adk/integrations/agent_identity/gcp_auth_provider.py new file mode 100644 index 0000000000..1a1aaec8f8 --- /dev/null +++ b/src/google/adk/integrations/agent_identity/gcp_auth_provider.py @@ -0,0 +1,284 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import logging +import os +import time + +from google.adk.agents.callback_context import CallbackContext +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import HttpAuth +from google.adk.auth.auth_credential import HttpCredentials +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_tool import AuthConfig +from google.adk.auth.base_auth_provider import BaseAuthProvider +from google.adk.flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME +from google.api_core.client_options import ClientOptions +from google.cloud.iamconnectorcredentials_v1alpha import IAMConnectorCredentialsServiceClient as Client +from google.cloud.iamconnectorcredentials_v1alpha import RetrieveCredentialsMetadata +from google.cloud.iamconnectorcredentials_v1alpha import RetrieveCredentialsRequest +from google.cloud.iamconnectorcredentials_v1alpha import RetrieveCredentialsResponse +from google.longrunning.operations_pb2 import Operation +from typing_extensions import override + +from .gcp_auth_provider_scheme import GcpAuthProviderScheme + +# Notes on the current Agent Identity Credentials service implementation: +# 1. The service does not yet support LROs, so even though the +# retrieve_credentials method returns an Operation object, the methods like +# operation.done() and operation.result() will not work yet. +# 2. For API key flows, the returned Operation contains the credentials. +# 3. For 2-legged OAuth flows, the returned Operation contains pending status, +# client needs to retry the request until response with credentials is +# returned or timeout occurs. +# 4. For 3-legged OAuth flows, the returned Operation contains consent pending +# status along with the authorization URI. + +# TODO: Catch specific exceptions instead of generic ones. + +logger = logging.getLogger("google_adk." + __name__) + +NON_INTERACTIVE_TOKEN_POLL_INTERVAL_SEC: float = 1.0 +NON_INTERACTIVE_TOKEN_POLL_TIMEOUT_SEC: float = 10.0 + + +def _construct_auth_credential( + response: RetrieveCredentialsResponse, +) -> AuthCredential: + """Constructs a simplified HTTP auth credential from the header-token tuple returned by the upstream service.""" + if not response.header or not response.token: + raise ValueError( + "Received either empty header or token from Agent Identity Credentials" + " service." + ) + + header_name, _, header_value = response.header.partition(":") + if ( + header_name.strip().lower() == "authorization" + and header_value.strip().lower().startswith("bearer") + ): + return AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token=response.token), + ), + ) + + # Handle custom header. + return AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + # For custom headers, scheme and credentials fields are not used. + scheme="", + credentials=HttpCredentials(), + additional_headers={ + response.header: response.token, + "X-GOOG-API-KEY": response.token, + }, + ), + ) + + +class GcpAuthProvider(BaseAuthProvider): + """An auth provider that uses the Agent Identity Credentials service to generate access tokens.""" + + _client: Client | None = None + + def __init__(self, client: Client | None = None): + self._client = client + + @property + @override + def supported_auth_schemes(self) -> tuple[type[GcpAuthProviderScheme], ...]: + return (GcpAuthProviderScheme,) + + def _get_client(self) -> Client: + """Lazy loads the client to avoid unnecessary setup on startup.""" + if self._client is None: + client_options = None + if host := os.environ.get("IAM_CONNECTOR_CREDENTIALS_TARGET_HOST"): + client_options = ClientOptions(api_endpoint=host) + self._client = Client(client_options=client_options, transport="rest") + return self._client + + async def _retrieve_credentials( + self, + user_id: str, + auth_scheme: GcpAuthProviderScheme, + ) -> Operation: + request = RetrieveCredentialsRequest( + connector=auth_scheme.name, + user_id=user_id, + scopes=auth_scheme.scopes, + continue_uri=auth_scheme.continue_uri or "", + force_refresh=False, + ) + # TODO: Use async client once available. Temporarily using threading to + # prevent blocking the event loop. + operation = await asyncio.to_thread( + self._get_client().retrieve_credentials, request + ) + return operation.operation + + def _unpack_operation( + self, operation: Operation + ) -> tuple[ + RetrieveCredentialsResponse | None, RetrieveCredentialsMetadata | None + ]: + """Deserializes the response and metadata from the operation.""" + response = None + metadata = None + if operation.response: + response = RetrieveCredentialsResponse.deserialize( + operation.response.value + ) + if operation.metadata: + metadata = RetrieveCredentialsMetadata.deserialize( + operation.metadata.value + ) + return response, metadata + + async def _poll_credentials( + self, user_id: str, auth_scheme: GcpAuthProviderScheme, timeout: float + ) -> Operation: + end_time = time.time() + timeout + while time.time() < end_time: + operation = await self._retrieve_credentials(user_id, auth_scheme) + if operation.done: + return operation + await asyncio.sleep(NON_INTERACTIVE_TOKEN_POLL_INTERVAL_SEC) + raise TimeoutError("Timeout waiting for credentials.") + + @staticmethod + def _is_consent_completed(context: CallbackContext) -> bool: + """Checks if the user consent flow is completed for the current function call.""" + if not context.function_call_id: + return False + + if not context.session: + return False + + events = context.session.events + target_tool_call_id = context.function_call_id + + # Find all relevant function calls and responses + euc_calls = {} + euc_responses = {} + + for event in events: + for call in event.get_function_calls(): + if call.name == REQUEST_EUC_FUNCTION_CALL_NAME: + euc_calls[call.id] = call + for response in event.get_function_responses(): + if response.name == REQUEST_EUC_FUNCTION_CALL_NAME: + euc_responses[response.id] = response + + # Check for a response that matches a call for the current tool invocation + for call_id, _ in euc_responses.items(): + if call_id in euc_calls: + call = euc_calls[call_id] + if ( + call.args + and call.args.get("function_call_id") == target_tool_call_id + ): + return True + return False + + @override + async def get_auth_credential( + self, + auth_config: AuthConfig, + context: CallbackContext | None = None, + ) -> AuthCredential: + """Retrieves credentials using the Agent Identity Credentials service. + + Args: + auth_config: The authentication configuration. + context: Optional context for the callback. + + Returns: + An AuthCredential instance. + + Raises: + ValueError: If auth_scheme is not a GcpAuthProviderScheme. + RuntimeError: If credential retrieval or polling fails. + """ + + auth_scheme = auth_config.auth_scheme + if not isinstance(auth_scheme, GcpAuthProviderScheme): + raise ValueError( + f"Expected GcpAuthProviderScheme, got {type(auth_scheme)}" + ) + + if context is None or context.user_id is None: + raise ValueError( + "GcpAuthProvider requires a context with a valid user_id." + ) + + user_id = context.user_id + + try: + operation = await self._retrieve_credentials(user_id, auth_scheme) + except Exception as e: + raise RuntimeError( + f"Failed to retrieve credential for user '{user_id}' on connector" + f" '{auth_scheme.name}'." + ) from e + + response, metadata = self._unpack_operation(operation) + + if operation.HasField("error"): + raise RuntimeError(f"Operation failed: {operation.error.message}") + + if operation.done: + logger.debug("Auth credential obtained immediately.") + return _construct_auth_credential(response) + + if metadata and metadata.consent_pending: + # Get 2-legged OAuth token. Allow enough time for token exchange. + try: + operation = await self._poll_credentials( + user_id, + auth_scheme, + timeout=NON_INTERACTIVE_TOKEN_POLL_TIMEOUT_SEC, + ) + if operation.HasField("error"): + raise RuntimeError(f"Operation failed: {operation.error.message}") + if operation.done: + logger.debug("Auth credential obtained after polling.") + response, _ = self._unpack_operation(operation) + return _construct_auth_credential(response) + except Exception as e: + raise RuntimeError( + f"Failed to retrieve credential for user '{user_id}' on connector" + f" '{auth_scheme.name}'." + ) from e + + if metadata is not None and metadata.uri_consent_required: + if self._is_consent_completed(context): + raise RuntimeError("Failed to retrieve consent based credential.") + + # Return AuthCredential with only auth_uri to trigger user consent flow. + return AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + auth_uri=metadata.uri_consent_required.authorization_uri, + nonce=metadata.uri_consent_required.consent_nonce, + ), + ) diff --git a/src/google/adk/integrations/agent_identity/gcp_auth_provider_scheme.py b/src/google/adk/integrations/agent_identity/gcp_auth_provider_scheme.py new file mode 100644 index 0000000000..e5ac769cca --- /dev/null +++ b/src/google/adk/integrations/agent_identity/gcp_auth_provider_scheme.py @@ -0,0 +1,48 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import List +from typing import Literal +from typing import Optional + +from google.adk.auth.auth_schemes import CustomAuthScheme +from pydantic import Field + + +class GcpAuthProviderScheme(CustomAuthScheme): + """The Agent Identity authentication scheme for Google Cloud Platform. + + Attributes: + name: The name of the GCP Auth Provider resource to use. + scopes: Optional. A list of OAuth2 scopes to request. + continue_uri: Optional. A type of redirect URI. It is distinct from the + standard OAuth2 redirect URI. Its purpose is to reauthenticate the user to + prevent phishing attacks and to finalize the managed OAuth flow. The + standard, Google-hosted OAuth2 redirect URI will redirect the user to this + continue URI. The agent will include this URI in every 3-legged OAuth + request sent to the upstream Agent Identity Credential service. Developers + must ensure this URI is hosted (e.g. on GCP, a third-party cloud, + on-prem), preferably alongside the agent client's web server. + TODO: Add public documentation link for more information once available. + type_: The type of the security scheme, always "gcpAuthProviderScheme". + """ + + type_: Literal["gcpAuthProviderScheme"] = Field( + default="gcpAuthProviderScheme", alias="type" + ) + name: str + scopes: Optional[List[str]] = None + continue_uri: Optional[str] = None diff --git a/src/google/adk/integrations/agent_registry/agent_registry.py b/src/google/adk/integrations/agent_registry/agent_registry.py index 9ff71cd9d5..62b0b53d03 100644 --- a/src/google/adk/integrations/agent_registry/agent_registry.py +++ b/src/google/adk/integrations/agent_registry/agent_registry.py @@ -36,6 +36,7 @@ from google.adk.agents.remote_a2a_agent import RemoteA2aAgent from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_schemes import AuthScheme +from google.adk.integrations.agent_identity.gcp_auth_provider_scheme import GcpAuthProviderScheme from google.adk.telemetry.tracing import GCP_MCP_SERVER_DESTINATION_ID from google.adk.tools.base_tool import BaseTool from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams @@ -292,8 +293,24 @@ def get_mcp_toolset( mcp_server_name: str, auth_scheme: AuthScheme | None = None, auth_credential: AuthCredential | None = None, + *, + continue_uri: str | None = None, ) -> McpToolset: - """Constructs an McpToolset instance from a registered MCP Server.""" + """Constructs an McpToolset from a registered MCP Server. + + If `auth_scheme` is omitted, it is automatically resolved from the server's + IAM bindings via `GcpAuthProviderScheme`. + + Args: + mcp_server_name: Resource name of the MCP Server. + auth_scheme: Optional auth scheme. Resolved via bindings if omitted. + auth_credential: Optional auth credential. + continue_uri: Optional continue URI to override what is in the auth + provider. + + Returns: + An McpToolset for the MCP server. + """ server_details = self.get_mcp_server(mcp_server_name) name = self._clean_name(server_details.get("displayName", mcp_server_name)) mcp_server_id = server_details.get("mcpServerId") @@ -313,6 +330,23 @@ def get_mcp_toolset( ) headers = self._get_auth_headers() if _is_google_api(endpoint_uri) else None + if mcp_server_id and not auth_scheme: + try: + bindings_data = self._make_request("bindings") + for b in bindings_data.get("bindings", []): + target_id = b.get("target", {}).get("identifier", "") + if target_id.endswith(mcp_server_id): + auth_provider = b.get("authProviderBinding", {}).get("authProvider") + if auth_provider: + auth_scheme = GcpAuthProviderScheme( + name=auth_provider, continue_uri=continue_uri + ) + break + except Exception as e: + logger.warning( + f"Failed to fetch bindings for MCP Server {mcp_server_name}: {e}" + ) + connection_params = StreamableHTTPConnectionParams( url=endpoint_uri, headers=headers, diff --git a/tests/integration/integrations/agent_identity/README.md b/tests/integration/integrations/agent_identity/README.md new file mode 100644 index 0000000000..a732b8717b --- /dev/null +++ b/tests/integration/integrations/agent_identity/README.md @@ -0,0 +1,26 @@ +# Integration tests for GCP Agent Identity Credentials service + +Verifies OAuth flows using GCP Agent Identity Credentials service. + +## Setup + +To set up your environment for the first time, run the `uv` setup script: +```bash +cd open_source_workspace +./uv_setup.sh +``` + +Then, activate the virtual environment: +```bash +source .venv/bin/activate +``` + +Then, install test specific packages +```bash +pip install google-cloud-iamconnectorcredentials +``` + +## Run Tests +```bash +pytest -s tests/integration/integrations/agent_identity +``` diff --git a/tests/integration/integrations/agent_identity/test_2lo_flow.py b/tests/integration/integrations/agent_identity/test_2lo_flow.py new file mode 100644 index 0000000000..7af790fef2 --- /dev/null +++ b/tests/integration/integrations/agent_identity/test_2lo_flow.py @@ -0,0 +1,202 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""E2E Integration Test for GCP Agent Identity Auth Provider two-legged OAuth Flow.""" + +import dataclasses +from typing import Any +from unittest import mock + +import pytest + +pytest.importorskip( + "google.cloud.iamconnectorcredentials_v1alpha", + reason="Requires google-cloud-iamconnectorcredentials", +) + +from google.adk import Agent +from google.adk import Runner +from google.adk.auth.auth_tool import AuthConfig +from google.adk.auth.credential_manager import CredentialManager +from google.adk.integrations.agent_identity import gcp_auth_provider +from google.adk.integrations.agent_identity import GcpAuthProvider +from google.adk.integrations.agent_identity import GcpAuthProviderScheme +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.tools.base_authenticated_tool import BaseAuthenticatedTool +from google.cloud.iamconnectorcredentials_v1alpha import RetrieveCredentialsRequest +from google.cloud.iamconnectorcredentials_v1alpha import RetrieveCredentialsResponse +from google.genai import types + +from tests.unittests import testing_utils + +DUMMY_TOKEN = "fake-gcp-2lo-token-123" +TEST_CONNECTOR_2LO = ( + "projects/test-project/locations/global/connectors/test-connector" +) + + +class DummyTool(BaseAuthenticatedTool): + + def __init__(self, auth_config: AuthConfig) -> None: + super().__init__( + name="dummy_tool", + description="Dummy tool for testing 2LO.", + auth_config=auth_config, + ) + + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type="OBJECT", + properties={}, + ), + ) + + async def _run_async_impl( + self, *, args: dict[str, Any] | None, tool_context: Any, credential: Any + ) -> Any: + # Return the token to prove the provider gave the expected credential + if credential.http and credential.http.credentials: + return credential.http.credentials.token + if credential.oauth2 and credential.oauth2.access_token: + return credential.oauth2.access_token + return None + + +# Mocked execution; pin to a single LLM backend to avoid duplicate runs. +@pytest.mark.parametrize("llm_backend", ["GOOGLE_AI"], indirect=True) +@dataclasses.dataclass +class _DummyOperation: + done: bool = True + error: Any = None + metadata: Any = None + response: Any = dataclasses.field(init=False) + operation: Any = dataclasses.field(init=False) + + def __post_init__(self) -> None: + self.response = mock.Mock() + mock_credential = RetrieveCredentialsResponse( + header="Authorization: Bearer", token=DUMMY_TOKEN + ) + self.response.value = RetrieveCredentialsResponse.serialize(mock_credential) + self.operation = self + + def HasField(self, field_name: str) -> bool: + return getattr(self, field_name, None) is not None + + +@pytest.mark.asyncio +async def test_gcp_agent_identity_2lo_gets_token() -> None: + """Test the end-to-end flow fetching 2LO OAuth token from GCP Agent Identity credentials service.""" + + # Clear registry to isolate tests + CredentialManager._auth_provider_registry._providers.clear() + + # 1. Setup mocked GCP Client to return the fake Bearer token + with mock.patch.object( + gcp_auth_provider, + "Client", + autospec=True, + ) as mock_client_cls: + + mock_operation = _DummyOperation() + + mock_client_cls.return_value.retrieve_credentials.return_value = ( + mock_operation + ) + + # 2. Configure Auth and DummyTool + auth_scheme = GcpAuthProviderScheme( + name=TEST_CONNECTOR_2LO, + scopes=["test-scope"], + ) + auth_config = AuthConfig(auth_scheme=auth_scheme) + dummy_tool = DummyTool(auth_config=auth_config) + + # 3. Setup LLM, Agent, and Runner + # We mock the LLM to just issue the tool call to 'dummy_tool' + mock_model = testing_utils.MockModel.create( + responses=[ + types.Part.from_function_call(name="dummy_tool", args={}), + "Tool executed successfully.", + ] + ) + + agent = Agent( + name="test_agent", + model=mock_model, + instruction="You are an agent. Use the dummy_tool when needed.", + tools=[dummy_tool], + ) + + runner = Runner( + app_name="test_mcp_2lo_app", + agent=agent, + session_service=InMemorySessionService(), + auto_create_session=True, + ) + + # 4. Register Auth Provider + CredentialManager.register_auth_provider(GcpAuthProvider()) + + # 5. Execute Flow + event_list = [] + async for event in runner.run_async( + user_id="test_user", + session_id="test_session1", + new_message=types.UserContent( + parts=[types.Part(text="Get me the token.")] + ), + ): + event_list.append(event) + + # 6. Assertions + + # Assert GCP Agent Identity client was invoked for credentials + expected_request = RetrieveCredentialsRequest( + connector=TEST_CONNECTOR_2LO, + user_id="test_user", + scopes=["test-scope"], + continue_uri="", + force_refresh=False, + ) + mock_client_cls.return_value.retrieve_credentials.assert_called_once_with( + expected_request + ) + + # 3 Events: Model FunctionCall -> Tool FunctionResponse -> Final LLM Text + assert len(event_list) == 3 + last_event = event_list[-1] + assert last_event.content.parts[0].text == "Tool executed successfully." + + # Validate that the mock model received the query and the tool callback + requests = mock_model.requests + # 2 Events: User Input -> Tool FunctionResponse + assert len(requests) == 2 + + # Extract the function response from the prompt payload sent to the LLM + last_request = requests[-1] + function_response = next( + ( + p.function_response + for p in last_request.contents[-1].parts + if p.function_response + ), + None, + ) + + assert function_response.name == "dummy_tool" + assert DUMMY_TOKEN in str(function_response.response) diff --git a/tests/integration/integrations/agent_identity/test_3lo_flow.py b/tests/integration/integrations/agent_identity/test_3lo_flow.py new file mode 100644 index 0000000000..2bf89e7ea6 --- /dev/null +++ b/tests/integration/integrations/agent_identity/test_3lo_flow.py @@ -0,0 +1,304 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""E2E Integration Test for 3LO flow using GCP Agent Identity service.""" + +import dataclasses +from typing import Any +from unittest import mock + +import pytest + +pytest.importorskip( + "google.cloud.iamconnectorcredentials_v1alpha", + reason="Requires google-cloud-iamconnectorcredentials", +) + +from google.adk import Agent +from google.adk import Runner +from google.adk.auth.auth_tool import AuthConfig +from google.adk.auth.credential_manager import CredentialManager +from google.adk.integrations.agent_identity import gcp_auth_provider +from google.adk.integrations.agent_identity import GcpAuthProvider +from google.adk.integrations.agent_identity import GcpAuthProviderScheme +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.tools.base_authenticated_tool import BaseAuthenticatedTool +from google.cloud.iamconnectorcredentials_v1alpha import RetrieveCredentialsMetadata +from google.cloud.iamconnectorcredentials_v1alpha import RetrieveCredentialsRequest +from google.cloud.iamconnectorcredentials_v1alpha import RetrieveCredentialsResponse +from google.genai import types + +from tests.unittests import testing_utils + +DUMMY_TOKEN = "mock-token-3legged" +TEST_CONNECTOR_3LO = ( + "projects/my-project/locations/some-location/connectors/test-connector-3lo" +) + + +class DummyTool(BaseAuthenticatedTool): + + def __init__(self, auth_config: AuthConfig) -> None: + super().__init__( + name="dummy_tool", + description="Dummy tool for testing 3LO.", + auth_config=auth_config, + ) + + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type="OBJECT", + properties={}, + ), + ) + + async def _run_async_impl( + self, *, args: dict[str, Any] | None, tool_context: Any, credential: Any + ) -> Any: + # Extract and return the token to prove the provider gave us the expected credential + if credential.http and credential.http.credentials: + return credential.http.credentials.token + if credential.oauth2 and credential.oauth2.access_token: + return credential.oauth2.access_token + + return None + + +@dataclasses.dataclass +class _MockOperation: + done: bool + response_obj: Any = None + metadata_obj: Any = None + error: Any = None + metadata: Any = dataclasses.field(init=False, default=None) + response: Any = dataclasses.field(init=False, default=None) + operation: Any = dataclasses.field(init=False) + + def __post_init__(self) -> None: + if self.metadata_obj: + self.metadata = mock.Mock() + self.metadata.value = RetrieveCredentialsMetadata.serialize( + self.metadata_obj + ) + if self.response_obj: + self.response = mock.Mock() + self.response.value = RetrieveCredentialsResponse.serialize( + self.response_obj + ) + self.operation = self + + def HasField(self, field_name: str) -> bool: + return getattr(self, field_name, None) is not None + + +class MockGcpClient: + """Lightweight in-memory mock for Agent Identity Credentials service 3LO Consent Flow.""" + + def __init__(self) -> None: + self.finalized_connectors = set() + + def retrieve_credentials( + self, + request: RetrieveCredentialsRequest | dict[str, Any] | None = None, + **kwargs: Any, + ) -> _MockOperation: + connector = ( + request.get("connector") + if isinstance(request, dict) + else getattr(request, "connector", None) + ) + + if connector in self.finalized_connectors: + mock_credential = RetrieveCredentialsResponse( + token=DUMMY_TOKEN, header="Authorization: Bearer" + ) + return _MockOperation(done=True, response_obj=mock_credential) + + # Otherwise, return Consent Required + # Auto-finalize for the next call to simulate user approval flow + self.finalized_connectors.add(connector) + + mock_metadata = RetrieveCredentialsMetadata( + uri_consent_required=RetrieveCredentialsMetadata.UriConsentRequired( + authorization_uri="http://mock-auth-uri", + consent_nonce="mock-consent-nonce", + ) + ) + return _MockOperation(done=False, metadata_obj=mock_metadata) + + +# Mocked execution; pin to a single LLM backend to avoid duplicate runs. +@pytest.mark.parametrize("llm_backend", ["GOOGLE_AI"], indirect=True) +@pytest.mark.asyncio +async def test_gcp_agent_identity_3lo_user_consent_flow() -> None: + # Clear registry to isolate tests + CredentialManager._auth_provider_registry._providers.clear() + + # 1. Setup mocked GCP Client to simulate stateful 3LO process + mock_gcp_client = MockGcpClient() + + with mock.patch.object( + gcp_auth_provider, + "Client", + autospec=True, + ) as mock_client_cls: + mock_client_cls.return_value.retrieve_credentials.side_effect = ( + mock_gcp_client.retrieve_credentials + ) + + # 2. Configure Auth and DummyTool + auth_scheme = GcpAuthProviderScheme( + name=TEST_CONNECTOR_3LO, + scopes=["test-scope"], + continue_uri="https://example.com/continue", + ) + auth_config = AuthConfig(auth_scheme=auth_scheme) + dummy_tool = DummyTool(auth_config=auth_config) + + # 3. Setup LLM, Agent, and Runner + # We mock the LLM to just issue the tool call to 'dummy_tool' + mock_model = testing_utils.MockModel.create( + responses=[ + types.Part.from_function_call(name="dummy_tool", args={}), + "I am waiting for your authorization.", + "Tool executed successfully.", + ] + ) + + agent = Agent( + name="test_agent", + model=mock_model, + instruction="You are an agent. Use the dummy_tool when needed.", + tools=[dummy_tool], + ) + + runner = Runner( + app_name="test_mcp_3lo_app", + agent=agent, + session_service=InMemorySessionService(), + auto_create_session=True, + ) + + # 4. Register Auth Provider + CredentialManager.register_auth_provider(GcpAuthProvider()) + + # 5. Execute Flow + session = await runner.session_service.create_session( + app_name="test_mcp_3lo_app", user_id="test_user" + ) + + event_list = [] + + # Step 5a: User sends message, Agent requests credential + async for event in runner.run_async( + user_id="test_user", + session_id=session.id, + new_message=types.UserContent( + parts=[types.Part(text="Get me the token.")] + ), + ): + event_list.append(event) + + def _find_auth_request_event(events): + for event in events: + for part in event.content.parts: + if ( + part.function_call + and part.function_call.name == "adk_request_credential" + ): + return event + return None + + auth_request_event = _find_auth_request_event(event_list) + + assert ( + auth_request_event + ), "Expected adk_request_credential tool call not found." + + # Step 5b: Simulate User Consent + call_part = next( + p for p in auth_request_event.content.parts if p.function_call + ) + request_auth_config = call_part.function_call.args.get("authConfig", {}) + + assert ( + request_auth_config.get("exchangedAuthCredential", {}) + .get("oauth2", {}) + .get("nonce") + == "mock-consent-nonce" + ) + + # Step 5c: User acknowledges credential request + response_part = types.Part.from_function_response( + name="adk_request_credential", response=request_auth_config + ) + response_part.function_response.id = call_part.function_call.id + + final_response_parts = [] + async for event in runner.run_async( + user_id="test_user", + session_id=session.id, + new_message=types.UserContent(parts=[response_part]), + ): + event_list.append(event) + if event.content: + for part in event.content.parts: + if part.text: + final_response_parts.append(part.text) + + final_response_text = "".join(final_response_parts) + + # 6. Assertions + + # Assert GCP Agent Identity client was invoked for credentials twice + # (Initial Request + Post-Consent call) + assert mock_client_cls.return_value.retrieve_credentials.call_count == 2 + expected_request = RetrieveCredentialsRequest( + connector=TEST_CONNECTOR_3LO, + user_id="test_user", + scopes=["test-scope"], + continue_uri="https://example.com/continue", + force_refresh=False, + ) + mock_client_cls.return_value.retrieve_credentials.assert_called_with( + expected_request + ) + + assert "Tool executed successfully." in final_response_text + + # Validate requests received by the mock model + requests = mock_model.requests + # Events: + # 1. User Input (Get me the token.) + # 2. LLM (I am waiting for your authorization.) + # 3. LLM (Tool executed successfully.) + assert len(requests) == 3 + + # Extract the function response from the prompt payload sent to the LLM + last_request = requests[-1] + function_response = next( + ( + p.function_response + for p in last_request.contents[-1].parts + if p.function_response + ), + None, + ) + + assert function_response is not None + assert function_response.name == "dummy_tool" + assert DUMMY_TOKEN in str(function_response.response) diff --git a/tests/unittests/artifacts/test_artifact_service.py b/tests/unittests/artifacts/test_artifact_service.py index 25294d4909..8b82397097 100644 --- a/tests/unittests/artifacts/test_artifact_service.py +++ b/tests/unittests/artifacts/test_artifact_service.py @@ -744,6 +744,69 @@ async def test_file_save_artifact_rejects_out_of_scope_paths( ) +@pytest.mark.asyncio +@pytest.mark.parametrize( + "user_id", + [ + "../escape", + "../../etc", + "foo/../../bar", + "valid/../..", + "..", + ".", + "has/slash", + "back\\slash", + "null\x00byte", + "", + ], +) +async def test_file_save_artifact_rejects_traversal_in_user_id( + tmp_path, user_id +): + """FileArtifactService rejects user_id values that escape root_dir.""" + artifact_service = FileArtifactService(root_dir=tmp_path / "artifacts") + part = types.Part(text="content") + with pytest.raises(InputValidationError): + await artifact_service.save_artifact( + app_name="myapp", + user_id=user_id, + session_id="sess123", + filename="safe.txt", + artifact=part, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "session_id", + [ + "../escape", + "../../tmp", + "foo/../../bar", + "..", + ".", + "has/slash", + "back\\slash", + "null\x00byte", + "", + ], +) +async def test_file_save_artifact_rejects_traversal_in_session_id( + tmp_path, session_id +): + """FileArtifactService rejects session_id values that escape root_dir.""" + artifact_service = FileArtifactService(root_dir=tmp_path / "artifacts") + part = types.Part(text="content") + with pytest.raises(InputValidationError): + await artifact_service.save_artifact( + app_name="myapp", + user_id="user123", + session_id=session_id, + filename="safe.txt", + artifact=part, + ) + + @pytest.mark.asyncio async def test_file_save_artifact_rejects_absolute_path_within_scope(tmp_path): """Absolute filenames are rejected even when they point inside the scope.""" diff --git a/tests/unittests/cli/utils/test_cli_deploy.py b/tests/unittests/cli/utils/test_cli_deploy.py index ed1aa8cc60..cf89863939 100644 --- a/tests/unittests/cli/utils/test_cli_deploy.py +++ b/tests/unittests/cli/utils/test_cli_deploy.py @@ -240,6 +240,20 @@ def test_agent_engine_app_template_compiles_with_windows_paths() -> None: compile(rendered, "", "exec") +def test_print_agent_engine_url() -> None: + """It should print the correct URL for a fully-qualified resource name.""" + with mock.patch("click.secho") as mocked_secho: + cli_deploy._print_agent_engine_url( + "projects/my-project/locations/us-central1/reasoningEngines/123456" + ) + mocked_secho.assert_called_once() + call_args = mocked_secho.call_args[0][0] + assert "my-project" in call_args + assert "us-central1" in call_args + assert "123456" in call_args + assert "playground" in call_args + + @pytest.mark.parametrize("include_requirements", [True, False]) def test_to_agent_engine_happy_path( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/unittests/evaluation/test_evaluation_generator.py b/tests/unittests/evaluation/test_evaluation_generator.py index a4aa8691fd..db30bfac75 100644 --- a/tests/unittests/evaluation/test_evaluation_generator.py +++ b/tests/unittests/evaluation/test_evaluation_generator.py @@ -226,6 +226,52 @@ def test_convert_multi_agent_final_responses( assert intermediate_events[0].author == "agent1" assert intermediate_events[0].content.parts[0].text == "First response" + def test_invocation_without_user_event_is_skipped(self): + """Invocations with no user-authored event must be skipped. + + Regression test for https://github.com/google/adk-python/issues/3760. + When a session contains an invocation_id whose events are all authored by + agents or tools (no 'user' event), convert_events_to_eval_invocations used + to leave user_content as a bare string, causing a Pydantic ValidationError + from Invocation.user_content which requires genai_types.Content. + The fix skips such invocations because they represent internal/system-driven + turns that are not meaningful for evaluation. + """ + events = [ + _build_event("agent", [types.Part(text="agent-only event")], "inv1"), + ] + + # Must not raise a Pydantic ValidationError. + invocations = EvaluationGenerator.convert_events_to_eval_invocations(events) + + assert ( + invocations == [] + ), "Invocations without a user event should be skipped." + + def test_mixed_invocations_skips_only_agent_only_ones(self): + """Only agent-only invocations are skipped; normal invocations are kept. + + Regression test for https://github.com/google/adk-python/issues/3760. + """ + events = [ + # inv1: normal user+agent turn — should be kept. + _build_event("user", [types.Part(text="Hello")], "inv1"), + _build_event("agent", [types.Part(text="Hi there!")], "inv1"), + # inv2: agent-only turn (e.g. background/system task) — should be skipped. + _build_event("agent", [types.Part(text="Internal work")], "inv2"), + # inv3: normal user+agent turn — should be kept. + _build_event("user", [types.Part(text="Follow-up")], "inv3"), + _build_event("agent", [types.Part(text="Sure!")], "inv3"), + ] + + invocations = EvaluationGenerator.convert_events_to_eval_invocations(events) + + assert len(invocations) == 2 + assert invocations[0].invocation_id == "inv1" + assert invocations[0].user_content.parts[0].text == "Hello" + assert invocations[1].invocation_id == "inv3" + assert invocations[1].user_content.parts[0].text == "Follow-up" + class TestGetAppDetailsByInvocationId: """Test cases for EvaluationGenerator._get_app_details_by_invocation_id method.""" diff --git a/tests/unittests/integrations/agent_identity/test_gcp_auth_provider.py b/tests/unittests/integrations/agent_identity/test_gcp_auth_provider.py new file mode 100644 index 0000000000..f638d0088a --- /dev/null +++ b/tests/unittests/integrations/agent_identity/test_gcp_auth_provider.py @@ -0,0 +1,434 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import Mock +from unittest.mock import patch + +import pytest + +pytest.importorskip( + "google.cloud.iamconnectorcredentials_v1alpha", + reason="Requires google-cloud-iamconnectorcredentials", +) + +from google.adk.agents.callback_context import CallbackContext +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_tool import AuthConfig +from google.adk.flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME +from google.adk.integrations.agent_identity import gcp_auth_provider +from google.adk.integrations.agent_identity import GcpAuthProvider +from google.adk.integrations.agent_identity import GcpAuthProviderScheme +from google.adk.sessions.session import Session +from google.cloud.iamconnectorcredentials_v1alpha import RetrieveCredentialsMetadata +from google.cloud.iamconnectorcredentials_v1alpha import RetrieveCredentialsResponse +from google.longrunning.operations_pb2 import Operation +from google.protobuf.any_pb2 import Any +from google.rpc.status_pb2 import Status + + +@pytest.fixture +def mock_client(): + return Mock(spec=gcp_auth_provider.Client) + + +@pytest.fixture +def provider(mock_client): + return GcpAuthProvider(client=mock_client) + + +@pytest.fixture +def auth_config(): + scheme = GcpAuthProviderScheme( + name="projects/test-project/locations/global/connectors/test-connector", + scopes=["test-scope"], + continue_uri="https://example.com/continue", + ) + return Mock(spec=AuthConfig, auth_scheme=scheme) + + +@pytest.fixture +def mock_operation(mocker, mock_client): + op = Operation(done=True) + + class DummyCall: + + def __init__(self, operation): + self.operation = operation + + mock_client.retrieve_credentials.return_value = DummyCall(op) + return op + + +@pytest.fixture +def context(): + context = Mock(spec=CallbackContext) + context.user_id = "user" + context.function_call_id = "call_123" + session = Mock(spec=Session) + session.events = [] + context.session = session + + return context + + +@patch.dict(gcp_auth_provider.os.environ, clear=True) +@patch.object(gcp_auth_provider, "Client") +def test_get_client_uses_rest_transport(mock_client_class): + provider = GcpAuthProvider() + provider._get_client() + + mock_client_class.assert_called_once() + _, kwargs = mock_client_class.call_args + assert kwargs.get("transport") == "rest" + + +@patch.dict( + gcp_auth_provider.os.environ, + {"IAM_CONNECTOR_CREDENTIALS_TARGET_HOST": "some-host"}, +) +@patch.object(gcp_auth_provider, "Client") +@patch.object(gcp_auth_provider, "ClientOptions") +def test_get_client_with_env_var(mock_client_options_class, mock_client_class): + provider = GcpAuthProvider() + client = provider._get_client() + + assert client == mock_client_class.return_value + mock_client_options_class.assert_called_once_with(api_endpoint="some-host") + mock_client_class.assert_called_once_with( + client_options=mock_client_options_class.return_value, transport="rest" + ) + + +# ============================================================================== +# Non-interactive auth flows (API key and 2-legged OAuth) +# ============================================================================== + + +async def test_get_auth_credential_raises_error_for_invalid_auth_scheme( + provider, context +): + """Test get_auth_credential raises ValueError for invalid auth scheme.""" + invalid_auth_config = Mock(spec=AuthConfig) + invalid_auth_config.auth_scheme = Mock() # Not GcpAuthProviderScheme + + with pytest.raises(ValueError, match="Expected GcpAuthProviderScheme, got"): + await provider.get_auth_credential(invalid_auth_config, context) + + +async def test_get_auth_credential_raises_error_if_context_is_missing( + provider, auth_config +): + """Test get_auth_credential raises ValueError if context is missing.""" + with pytest.raises( + ValueError, + match="GcpAuthProvider requires a context with a valid user_id", + ): + await provider.get_auth_credential(auth_config, context=None) + + +async def test_get_auth_credential_raises_error_if_user_id_is_missing( + provider, auth_config +): + """Test get_auth_credential raises ValueError if user_id is missing.""" + context = Mock(spec=CallbackContext) + context.user_id = None + with pytest.raises( + ValueError, + match="GcpAuthProvider requires a context with a valid user_id", + ): + await provider.get_auth_credential(auth_config, context=context) + + +async def test_get_auth_credential_returns_credential_if_available_immediately( + mock_client, + mock_operation, + auth_config, + context, + provider, +): + """Test get_auth_credential returns credential if available immediately.""" + mock_credential = RetrieveCredentialsResponse( + header="Authorization: Bearer", token="test-token" + ) + mock_operation.response.value = RetrieveCredentialsResponse.serialize( + mock_credential + ) + + auth_credential = await provider.get_auth_credential(auth_config, context) + + assert auth_credential.auth_type == AuthCredentialTypes.HTTP + assert auth_credential.http.scheme == "bearer" + assert auth_credential.http.credentials.token == "test-token" + mock_client.retrieve_credentials.assert_called_once() + + +async def test_get_auth_credential_raises_error_if_upstream_returns_empty_header( + mock_operation, + auth_config, + context, + provider, +): + """Test get_auth_credential raises RuntimeError for empty header.""" + mock_credential = RetrieveCredentialsResponse(header="", token="test-token") + mock_operation.response.value = RetrieveCredentialsResponse.serialize( + mock_credential + ) + + with pytest.raises( + ValueError, + match=( + "Received either empty header or token from Agent Identity" + " Credentials service." + ), + ): + await provider.get_auth_credential(auth_config, context) + + +async def test_get_auth_credential_raises_error_if_upstream_returns_empty_token( + mock_operation, + auth_config, + context, + provider, +): + """Test get_auth_credential raises RuntimeError for empty token.""" + mock_credential = RetrieveCredentialsResponse( + header="Authorization: Bearer", token="" + ) + mock_operation.response.value = RetrieveCredentialsResponse.serialize( + mock_credential + ) + + with pytest.raises( + ValueError, + match=( + "Received either empty header or token from Agent Identity" + " Credentials service." + ), + ): + await provider.get_auth_credential(auth_config, context) + + +async def test_get_auth_credential_returns_credential_if_upstream_returns_custom_header( + mock_operation, + auth_config, + context, + provider, +): + """Test get_auth_credential returns valid credential for custom header and sets X-GOOG-API-KEY header.""" + mock_credential = RetrieveCredentialsResponse( + header="some-x-api-key", token="test-token" + ) + mock_operation.response.value = RetrieveCredentialsResponse.serialize( + mock_credential + ) + + auth_credential = await provider.get_auth_credential(auth_config, context) + + assert auth_credential.auth_type == AuthCredentialTypes.HTTP + assert not auth_credential.http.scheme + assert auth_credential.http.credentials.token is None + assert auth_credential.http.additional_headers == { + "some-x-api-key": "test-token", + "X-GOOG-API-KEY": "test-token", + } + + +async def test_get_auth_credential_raises_error_if_upstream_operation_errors( + mock_operation, auth_config, context, provider +): + """Test get_auth_credential raises RuntimeError for failed operations.""" + mock_operation.error.message = "OAuth server error" + mock_operation.done = False + + with pytest.raises( + RuntimeError, match="Operation failed: OAuth server error" + ): + await provider.get_auth_credential(auth_config, context) + + +async def test_get_auth_credential_raises_error_if_upstream_call_fails( + mock_client, auth_config, context, provider +): + """Test get_auth_credential raises RuntimeError for failed calls.""" + mock_client.retrieve_credentials.side_effect = Exception( + "API Quota Exhausted" + ) + + with pytest.raises( + RuntimeError, + match="Failed to retrieve credential for user 'user' on connector", + ) as exc_info: + await provider.get_auth_credential(auth_config, context) + + # Assert that the original Exception is the chained cause! + assert str(exc_info.value.__cause__) == "API Quota Exhausted" + + +@patch.object(gcp_auth_provider.time, "time") +async def test_get_auth_credential_raises_error_if_polling_times_out( + mock_time, + mock_operation, + auth_config, + context, + provider, +): + """Test get_auth_credential raises RuntimeError if polling times out.""" + + # Force the operation into the polling loop state + meta_pb = RetrieveCredentialsMetadata.pb()() + meta_pb.consent_pending.SetInParent() + meta = RetrieveCredentialsMetadata.deserialize(meta_pb.SerializeToString()) + mock_operation.metadata.value = RetrieveCredentialsMetadata.serialize(meta) + + # First call sets start_time=0.0, second call checks time > timeout + # (20.0 > 10.0) + mock_time.side_effect = [0.0, 20.0] + + mock_metadata = Mock(spec=RetrieveCredentialsMetadata) + mock_metadata.consent_pending = True + mock_metadata.uri_consent_required = False + mock_operation.done = True + mock_operation.ClearField("error") + mock_client = Mock(spec=gcp_auth_provider.Client) + mock_client.retrieve_credentials.side_effect = Exception( + "Timeout waiting for credentials." + ) + provider._client = mock_client + + with pytest.raises( + RuntimeError, + match="Failed to retrieve credential for user 'user' on connector", + ) as exc_info: + await provider.get_auth_credential(auth_config, context) + + assert "Timeout waiting for credentials." in str(exc_info.value.__cause__) + + +# ============================================================================== +# Interactive Auth Flows (3-legged OAuth for User Consents) +# ============================================================================== + + +async def test_get_auth_credential_initiates_user_consent( + mock_operation, auth_config, context, provider +): + # Explicitly set the mock behavior for this test + expected_uri = "https://example.com/auth" + expected_nonce = "sample-nonce-123" + meta = RetrieveCredentialsMetadata({ + "uri_consent_required": { + "authorization_uri": expected_uri, + "consent_nonce": expected_nonce, + } + }) + mock_operation.metadata.value = RetrieveCredentialsMetadata.serialize(meta) + mock_operation.done = False + # Assert that there is no prior user consent completion event + assert not context.session.events + + credential = await provider.get_auth_credential(auth_config, context) + + assert credential is not None + assert credential.auth_type == AuthCredentialTypes.OAUTH2 + assert credential.oauth2.auth_uri == expected_uri + assert credential.oauth2.nonce == expected_nonce + + +async def test_get_auth_credential_returns_fresh_auth_uri_for_repeated_requests( + mock_client, mock_operation, auth_config, context, provider +): + """Test that repeated calls fetch fresh auth URIs if consent is still pending.""" + # Arrange: Explicit initial URI + initial_uri = "https://example.com/auth" + initial_nonce = "initial-nonce-123" + meta1 = RetrieveCredentialsMetadata({ + "uri_consent_required": { + "authorization_uri": initial_uri, + "consent_nonce": initial_nonce, + } + }) + mock_operation.metadata.value = RetrieveCredentialsMetadata.serialize(meta1) + mock_operation.done = False + + credential1 = await provider.get_auth_credential(auth_config, context) + assert credential1.oauth2.auth_uri == initial_uri + assert credential1.oauth2.nonce == initial_nonce + + # Arrange: Explicit new URI for the second call + fresh_auth_uri = "https://example.com/auth_new" + fresh_nonce = "fresh-nonce-456" + meta2 = RetrieveCredentialsMetadata({ + "uri_consent_required": { + "authorization_uri": fresh_auth_uri, + "consent_nonce": fresh_nonce, + } + }) + mock_operation.metadata.value = RetrieveCredentialsMetadata.serialize(meta2) + + credential2 = await provider.get_auth_credential(auth_config, context) + + assert mock_client.retrieve_credentials.call_count == 2 + assert credential2.oauth2.auth_uri == fresh_auth_uri + assert credential2.oauth2.nonce == fresh_nonce + + +async def test_get_auth_credential_returns_token_if_consent_was_completed( + mock_operation, auth_config, context, provider +): + # Setup mock credential for successful credential retrieval + mock_credential = RetrieveCredentialsResponse( + header="Authorization: Bearer", token="test-token" + ) + mock_operation.response.value = RetrieveCredentialsResponse.serialize( + mock_credential + ) + + # Create mock events + # 1. FunctionCall event for adk_request_credential + function_call = Mock() + function_call.id = "auth-req-1" + function_call.name = REQUEST_EUC_FUNCTION_CALL_NAME + function_call.args = {"function_call_id": "call-123"} + + event1 = Mock() + event1.get_function_calls.return_value = [function_call] + event1.get_function_responses.return_value = [] + + # 2. FunctionResponse event for adk_request_credential + function_response = Mock() + function_response.id = "auth-req-1" + function_response.name = REQUEST_EUC_FUNCTION_CALL_NAME + + event2 = Mock() + event2.get_function_calls.return_value = [] + event2.get_function_responses.return_value = [function_response] + + # Setup tool context and event history (order of events matters) + context.session.events = [event1, event2] + context.function_call_id = "call-123" + + # Also set uri_consent_required to True-ish so it enters the check block + meta = RetrieveCredentialsMetadata( + uri_consent_required=RetrieveCredentialsMetadata.UriConsentRequired() + ) + mock_operation.metadata.value = RetrieveCredentialsMetadata.serialize(meta) + + # Execute + auth_credential = await provider.get_auth_credential(auth_config, context) + + # Verify + assert auth_credential is not None + assert auth_credential.auth_type == AuthCredentialTypes.HTTP + assert auth_credential.http.scheme == "bearer" + assert auth_credential.http.credentials.token == "test-token" diff --git a/tests/unittests/integrations/agent_registry/test_agent_registry.py b/tests/unittests/integrations/agent_registry/test_agent_registry.py index 1a454131ef..dd1678d780 100644 --- a/tests/unittests/integrations/agent_registry/test_agent_registry.py +++ b/tests/unittests/integrations/agent_registry/test_agent_registry.py @@ -637,3 +637,48 @@ def test_get_model_name_raises_value_error_if_no_uri( mock_get_endpoint.return_value = {} with pytest.raises(ValueError, match="Connection URI not found"): registry.get_model_name("test-endpoint") + + @patch.object(AgentRegistry, "_make_request") + def test_get_mcp_toolset_with_binding(self, mock_make_request, registry): + def side_effect(*args, **kwargs): + if args[0] == "test-mcp": + return { + "displayName": "TestPrefix", + "mcpServerId": "server-456", + "interfaces": [{ + "url": "https://mcp.com", + "protocolBinding": "JSONRPC", + }], + } + if args[0] == "bindings": + return { + "bindings": [{ + "target": { + "identifier": ( + "urn:mcp:projects-123:projects:123:locations:l:mcpServers:server-456" + ) + }, + "authProviderBinding": { + "authProvider": ( + "projects/123/locations/l/authProviders/ap-789" + ) + }, + }] + } + return {} + + mock_make_request.side_effect = side_effect + + registry._credentials.token = "token" + registry._credentials.refresh = MagicMock() + + toolset = registry.get_mcp_toolset( + "test-mcp", continue_uri="https://override.com/continue" + ) + assert isinstance(toolset, McpToolset) + assert toolset._auth_scheme is not None + assert ( + toolset._auth_scheme.name + == "projects/123/locations/l/authProviders/ap-789" + ) + assert toolset._auth_scheme.continue_uri == "https://override.com/continue"