diff --git a/README.md b/README.md index d135fa4..6e2cb16 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,8 @@ - [Get Application by Id](#get-application-by-id) - [Models](#models) - [Get Model by Name](#get-model-by-name) + - [User](#user) + - [Get Authenticated User Info](#get-authenticated-user-info) - [Toolsets](#toolsets) - [Get Toolset by Id](#get-toolset-by-id) - [Resource Permissions](#resource-permissions) @@ -856,6 +858,48 @@ ModelInfo( ) ``` +### User + +#### Get Authenticated User Info + +To retrieve information about the currently authenticated user: + +```python +# Sync +user_info = client.user.info() + +# Async +user_info = await async_client.user.info() +``` + +As a result, you will receive a `UserInfo` object. When authenticated with an +API key: + +```python +UserInfo( + roles=["default"], + project="PROJECT-NAME", + userClaims=None, +) +``` + +When authenticated with an access token: + +```python +UserInfo( + roles=["BA"], + project=None, + userClaims={ + "email": ["user_email"], + "sub": ["user_sub"], + }, +) +``` + +`userClaims` is returned as an opaque `dict` because its contents depend on the +identity provider. `UserInfo` also preserves any additional fields the DIAL +deployment may return, so forward compatibility is retained. + ### Toolsets #### Get Toolset by Id diff --git a/aidial_client/__init__.py b/aidial_client/__init__.py index edf195f..43f5f0b 100644 --- a/aidial_client/__init__.py +++ b/aidial_client/__init__.py @@ -12,6 +12,7 @@ from aidial_client.types.client_channel import SigninResult from aidial_client.types.model import ModelInfo, ModelLimits, ModelPricing from aidial_client.types.toolset import ToolsetInfo +from aidial_client.types.user import UserInfo __all__ = [ "Dial", @@ -32,4 +33,5 @@ "ModelPricing", "ModelLimits", "SigninResult", + "UserInfo", ] diff --git a/aidial_client/_client.py b/aidial_client/_client.py index 36313af..03eb914 100644 --- a/aidial_client/_client.py +++ b/aidial_client/_client.py @@ -120,6 +120,7 @@ def _init_resources(self) -> None: self.client_channel = resources.ClientChannel( http_client=self._http_client ) + self.user = resources.User(http_client=self._http_client) def _create_http_client(self) -> SyncHTTPClient: return SyncHTTPClient( @@ -224,6 +225,7 @@ def _init_resources(self) -> None: self.client_channel = resources.AsyncClientChannel( http_client=self._http_client ) + self.user = resources.AsyncUser(http_client=self._http_client) def _create_http_client(self) -> AsyncHTTPClient: return AsyncHTTPClient( diff --git a/aidial_client/resources/__init__.py b/aidial_client/resources/__init__.py index ed55b9a..170b700 100644 --- a/aidial_client/resources/__init__.py +++ b/aidial_client/resources/__init__.py @@ -10,6 +10,7 @@ ResourcePermissions, ) from aidial_client.resources.toolset import AsyncToolset, Toolset +from aidial_client.resources.user import AsyncUser, User from .application import Application, AsyncApplication from .bucket import AsyncBucket, Bucket @@ -40,4 +41,6 @@ "AsyncResourcePermissions", "ClientChannel", "AsyncClientChannel", + "User", + "AsyncUser", ] diff --git a/aidial_client/resources/user.py b/aidial_client/resources/user.py new file mode 100644 index 0000000..47c20b9 --- /dev/null +++ b/aidial_client/resources/user.py @@ -0,0 +1,19 @@ +from aidial_client._internal_types._http_request import FinalRequestOptions +from aidial_client.resources.base import AsyncResource, Resource +from aidial_client.types.user import UserInfo + + +class User(Resource): + def info(self) -> UserInfo: + return self.http_client.request( + cast_to=UserInfo, + options=FinalRequestOptions(method="GET", url="v1/user/info"), + ) + + +class AsyncUser(AsyncResource): + async def info(self) -> UserInfo: + return await self.http_client.request( + cast_to=UserInfo, + options=FinalRequestOptions(method="GET", url="v1/user/info"), + ) diff --git a/aidial_client/types/user.py b/aidial_client/types/user.py new file mode 100644 index 0000000..089fc95 --- /dev/null +++ b/aidial_client/types/user.py @@ -0,0 +1,11 @@ +from typing import Any + +from aidial_client._internal_types._model import ExtraAllowModel + + +class UserInfo(ExtraAllowModel): + """Information about the authenticated user or API key.""" + + roles: list[str] + project: str | None = None + userClaims: dict[str, Any] | None = None # depends on the IdP, so opaque diff --git a/tests/resources/test_user.py b/tests/resources/test_user.py new file mode 100644 index 0000000..3a5520f --- /dev/null +++ b/tests/resources/test_user.py @@ -0,0 +1,105 @@ +import httpx +import pytest + +from aidial_client import AsyncDial, Dial +from aidial_client._exception import DialException +from aidial_client.types.user import UserInfo +from tests.client_mock import get_async_client_mock, get_client_mock + +BASE_URL = "http://dial.core" + +USER_INFO_MOCK = { + "project": "PROJECT-NAME", + "roles": ["default"], +} + +USER_INFO_TOKEN_MOCK = { + "roles": ["BA"], + "userClaims": { + "email": ["user@example.com"], + "sub": ["user-123"], + }, +} + + +def test_get_user_info(): + client = get_client_mock(status_code=200, json_mock=USER_INFO_MOCK) + result = client.user.info() + assert isinstance(result, UserInfo) + assert result.project == "PROJECT-NAME" + assert result.roles == ["default"] + assert result.userClaims is None + + +@pytest.mark.asyncio +async def test_async_get_user_info(): + client = get_async_client_mock(status_code=200, json_mock=USER_INFO_MOCK) + result = await client.user.info() + assert isinstance(result, UserInfo) + assert result.project == "PROJECT-NAME" + assert result.roles == ["default"] + assert result.userClaims is None + + +def test_get_user_info_with_token_claims(): + client = get_client_mock(status_code=200, json_mock=USER_INFO_TOKEN_MOCK) + result = client.user.info() + assert isinstance(result, UserInfo) + assert result.project is None + assert result.roles == ["BA"] + assert result.userClaims == { + "email": ["user@example.com"], + "sub": ["user-123"], + } + + +def test_get_user_info_request_method_and_url(): + captured: list[httpx.Request] = [] + client = Dial(api_key="dummy", base_url=BASE_URL) + + def send_mock(request: httpx.Request, **kwargs): + captured.append(request) + return httpx.Response(200, request=request, json=USER_INFO_MOCK) + + client._http_client._internal_http_client.send = send_mock + client.user.info() + + assert len(captured) == 1 + assert captured[0].method == "GET" + assert captured[0].url.path == "/v1/user/info" + + +@pytest.mark.asyncio +async def test_async_get_user_info_request_method_and_url(): + captured: list[httpx.Request] = [] + client = AsyncDial(api_key="dummy", base_url=BASE_URL) + + async def send_mock(request: httpx.Request, **kwargs): + captured.append(request) + return httpx.Response(200, request=request, json=USER_INFO_MOCK) + + client._http_client._internal_http_client.send = send_mock + await client.user.info() + + assert len(captured) == 1 + assert captured[0].method == "GET" + assert captured[0].url.path == "/v1/user/info" + + +def test_get_user_info_http_error(): + client = get_client_mock( + status_code=401, + json_mock={"error": {"message": "Unauthorized", "type": "auth_error"}}, + ) + with pytest.raises(DialException): + client.user.info() + + +@pytest.mark.asyncio +async def test_async_get_user_info_http_error(): + client = get_async_client_mock( + status_code=401, + json_mock={"error": {"message": "Unauthorized", "type": "auth_error"}}, + ) + with pytest.raises(DialException): + await client.user.info()