diff --git a/msal/application.py b/msal/application.py index 084f9bf3..cadba9b2 100644 --- a/msal/application.py +++ b/msal/application.py @@ -242,6 +242,7 @@ class ClientApplication(object): ACQUIRE_TOKEN_FOR_CLIENT_ID = "730" ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID = "832" ACQUIRE_TOKEN_INTERACTIVE = "169" + ACQUIRE_TOKEN_BY_USER_FIC_ID = "950" GET_ACCOUNTS_ID = "902" REMOVE_ACCOUNT_ID = "903" @@ -2572,3 +2573,62 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP telemetry_context.update_telemetry(response) return response + + def acquire_token_by_user_federated_identity_credential( + self, scopes, assertion, username=None, user_object_id=None, + claims_challenge=None, **kwargs): + """Acquires a user-scoped token using the ``user_fic`` grant type. + + This method exchanges a federated identity credential (typically an + agent instance token from Leg 2 of the agent identity protocol) for + a user-scoped access token, enabling an agent to act on behalf of + a specific user. + + :param list[str] scopes: Scopes required by downstream API (a resource). + :param str assertion: + The federated identity credential token (e.g. the instance token + obtained from Leg 2 of the agent identity flow). + :param str username: + The target user's UPN (User Principal Name). + Mutually exclusive with ``user_object_id``. + :param str user_object_id: + The target user's Object ID. + Mutually exclusive with ``username``. + :param claims_challenge: + The claims_challenge parameter requests specific claims requested by the resource provider + in the form of a claims_challenge directive in the www-authenticate header to be + returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. + It is a string of a JSON object which contains lists of claims being requested from these locations. + + :return: A dict representing the json response from Microsoft Entra: + + - A successful response would contain "access_token" key, + - an error response would contain "error" and usually "error_description". + """ + # Input validation + if not assertion: + raise ValueError("assertion is required and must be non-empty") + if not username and not user_object_id: + raise ValueError( + "Either username or user_object_id must be provided") + if username and user_object_id: + raise ValueError( + "username and user_object_id are mutually exclusive") + + telemetry_context = self._build_telemetry_context( + self.ACQUIRE_TOKEN_BY_USER_FIC_ID) + response = _clean_up(self.client.obtain_token_by_user_fic( + scope=self._decorate_scope(scopes), + assertion=assertion, + username=username, + user_object_id=user_object_id, + headers=telemetry_context.generate_headers(), + data=dict( + kwargs.pop("data", {}), + claims=_merge_claims_challenge_and_capabilities( + self._client_capabilities, claims_challenge)), + **kwargs)) + if "access_token" in response: + response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP + telemetry_context.update_telemetry(response) + return response diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index 68b0e84e..33caca73 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -7,6 +7,7 @@ except ImportError: from urlparse import parse_qs, urlparse, urlunparse from urllib import urlencode, quote_plus +import inspect import logging import warnings import time @@ -104,6 +105,11 @@ def __init__( or a raw JWT assertion in bytes (which we will relay to http layer). It can also be a callable (recommended), so that we will do lazy creation of an assertion. + + The callable may accept zero arguments (legacy) or one argument. + When it accepts one argument, it will receive a dict containing + ``"client_id"``, ``"token_endpoint"``, and optionally ``"fmi_path"`` + (when an FMI path is set on the current request). client_assertion_type (str): The type of your :attr:`client_assertion` parameter. It is typically the value of :attr:`CLIENT_ASSERTION_TYPE_SAML2` or @@ -168,6 +174,41 @@ def __init__( # A workaround for requests not supporting session-wide timeout self._http_client.request, timeout=timeout) + @staticmethod + def _accepts_context(func): + """Check if a callable requires at least one positional argument. + + Returns True only when the callable has a positional parameter + **without** a default value. This ensures that legacy zero-arg + callables — including ``lambda token=token: token`` patterns + where every positional param has a default — are still invoked + with no arguments. + """ + try: + sig = inspect.signature(func) + for p in sig.parameters.values(): + if p.kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) and p.default is inspect.Parameter.empty: + return True + return False + except (ValueError, TypeError): + return False # Signature not inspectable; treat as zero-arg + + def _invoke_assertion_callable(self, assertion_callable, data=None): + """Invoke an assertion callable, passing context if it accepts one.""" + if self._accepts_context(assertion_callable): + context = { + "client_id": self.client_id, + "token_endpoint": self.configuration.get( + "token_endpoint", ""), + } + if data and data.get("fmi_path"): + context["fmi_path"] = data["fmi_path"] + return assertion_callable(context) + return assertion_callable() + def _build_auth_request_params(self, response_type, **kwargs): # response_type is a string defined in # https://tools.ietf.org/html/rfc6749#section-3.1.1 @@ -198,11 +239,11 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 # See https://tools.ietf.org/html/rfc7521#section-4.2 encoder = self.client_assertion_encoders.get( self.default_body["client_assertion_type"], lambda a: a) - _data["client_assertion"] = encoder( - self.client_assertion() # Do lazy on-the-fly computation - if callable(self.client_assertion) else self.client_assertion - ) # The type is bytes, which is preferable. See also: - # https://github.com/psf/requests/issues/4503#issuecomment-455001070 + if callable(self.client_assertion): + raw = self._invoke_assertion_callable(self.client_assertion, data) + else: + raw = self.client_assertion + _data["client_assertion"] = encoder(raw) _data.update(self.default_body) # It may contain authen parameters _data.update(data or {}) # So the content in data param prevails @@ -770,6 +811,34 @@ class initialization. data.update(scope=scope) return self._obtain_token("client_credentials", data=data, **kwargs) + def obtain_token_by_user_fic( + self, scope, assertion, username=None, user_object_id=None, + **kwargs): + """Obtain token using the ``user_fic`` grant type. + + This exchanges a federated identity credential (e.g. an agent + instance token) for a user-scoped access token. + + :param scope: Scopes for the target resource (already decorated + with OIDC scopes by the caller). + :param str assertion: The federated identity credential token. + :param str username: The target user's UPN (mutually exclusive + with *user_object_id*). + :param str user_object_id: The target user's Object ID (mutually + exclusive with *username*). + """ + data = kwargs.pop("data", {}) + data.update( + scope=scope, + user_federated_identity_credential=assertion, + client_info="1", + ) + if user_object_id: + data["user_id"] = str(user_object_id) + elif username: + data["username"] = username + return self._obtain_token("user_fic", data=data, **kwargs) + def __init__(self, server_configuration, client_id, on_obtaining_tokens=lambda event: None, # event is defined in _obtain_token(...) diff --git a/msal/token_cache.py b/msal/token_cache.py index d6e2a2b1..78999292 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -65,6 +65,11 @@ "token_type", "req_cnf", "key_id", + # user_fic grant parameters — these are standard body params for the + # user_fic flow; FIC tokens use normal user cache keys (not extended). + "user_federated_identity_credential", + "user_id", + "client_info", }) @@ -301,6 +306,7 @@ def make_clean_copy(dictionary, sensitive_fields): # Masks sensitive info event, data=make_clean_copy(event.get("data", {}), ( "password", "client_secret", "refresh_token", "assertion", + "user_federated_identity_credential", )), response=make_clean_copy(event.get("response", {}), ( "id_token_claims", # Provided by broker diff --git a/tests/test_application.py b/tests/test_application.py index 54da96c0..e56be943 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1,5 +1,6 @@ # Note: Since Aug 2019 we move all e2e tests into test_e2e.py, # so this test_application file contains only unit tests without dependency. +import base64 import json import logging import sys @@ -1090,4 +1091,356 @@ def mock_post(url, headers=None, *args, **kwargs): result = app.acquire_token_for_client([scope]) self.assertEqual(result[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE) self.assertEqual("AT_with_valid_scope1_valid_scope2_scopes", result.get("access_token")) - self.assertIsNone(result.get("scope"), "scope field is not returned when token comes from cache") \ No newline at end of file + self.assertIsNone(result.get("scope"), "scope field is not returned when token comes from cache") + + +def _build_user_fic_response(uid="user_oid", utid="tenant_id", access_token="user_at"): + """Build a mock user_fic response with client_info and id_token.""" + client_info = base64.b64encode(json.dumps({ + "uid": uid, "utid": utid, + }).encode()).decode("utf-8") + id_token_claims = { + "iss": "https://login.microsoftonline.com/tenant_id/v2.0", + "sub": "subject", + "aud": "agent_app_id", + "exp": time.time() + 3600, + "iat": time.time(), + "oid": uid, + "preferred_username": "user@contoso.com", + "tid": utid, + } + id_token = "header.%s.signature" % base64.b64encode( + json.dumps(id_token_claims).encode()).decode("utf-8") + return json.dumps({ + "access_token": access_token, + "expires_in": 3600, + "token_type": "Bearer", + "client_info": client_info, + "id_token": id_token, + "refresh_token": "a_refresh_token", + }) + + +@patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) +class TestUserFicProtocol(unittest.TestCase): + """Tests that acquire_token_by_user_federated_identity_credential sends correct POST body.""" + + def _make_app(self): + return ConfidentialClientApplication( + "agent_app_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant") + + def test_sends_correct_grant_type_and_params(self): + app = self._make_app() + captured_data = {} + + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse(status_code=200, text=_build_user_fic_response()) + + result = app.acquire_token_by_user_federated_identity_credential( + ["https://graph.microsoft.com/.default"], + assertion="instance_token_t2", + username="user@contoso.com", + post=mock_post) + self.assertIn("access_token", result) + self.assertEqual("user_fic", captured_data.get("grant_type")) + self.assertEqual("instance_token_t2", + captured_data.get("user_federated_identity_credential")) + self.assertEqual("1", captured_data.get("client_info")) + self.assertEqual("agent_app_id", captured_data.get("client_id")) + + def test_scope_includes_oidc_scopes(self): + app = self._make_app() + captured_data = {} + + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse(status_code=200, text=_build_user_fic_response()) + + app.acquire_token_by_user_federated_identity_credential( + ["https://graph.microsoft.com/.default"], + assertion="t2", username="user@contoso.com", post=mock_post) + scope_str = captured_data.get("scope", "") + for oidc_scope in ("openid", "offline_access", "profile"): + self.assertIn(oidc_scope, scope_str, + "OIDC scope '{}' should be present".format(oidc_scope)) + + def test_with_username_sends_username_not_user_id(self): + app = self._make_app() + captured_data = {} + + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse(status_code=200, text=_build_user_fic_response()) + + app.acquire_token_by_user_federated_identity_credential( + ["scope"], assertion="t2", username="user@contoso.com", post=mock_post) + self.assertEqual("user@contoso.com", captured_data.get("username")) + self.assertNotIn("user_id", captured_data, + "user_id should NOT be in body when username is provided") + + def test_with_oid_sends_user_id_not_username(self): + app = self._make_app() + captured_data = {} + + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse(status_code=200, text=_build_user_fic_response()) + + app.acquire_token_by_user_federated_identity_credential( + ["scope"], assertion="t2", + user_object_id="00000000-0000-0000-0000-000000000001", + post=mock_post) + self.assertEqual("00000000-0000-0000-0000-000000000001", + captured_data.get("user_id")) + self.assertNotIn("username", captured_data, + "username should NOT be in body when user_object_id is provided") + + +@patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) +class TestUserFicCacheBehavior(unittest.TestCase): + """Tests that user_fic tokens are stored in user cache with account info.""" + + def _make_app(self): + return ConfidentialClientApplication( + "agent_app_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant") + + def test_token_stored_in_user_cache_with_account(self): + app = self._make_app() + + def mock_post(url, headers=None, data=None, *args, **kwargs): + return MinimalResponse(status_code=200, text=_build_user_fic_response( + uid="user_oid", utid="tenant_id", access_token="fic_at")) + + result = app.acquire_token_by_user_federated_identity_credential( + ["https://graph.microsoft.com/.default"], + assertion="t2", username="user@contoso.com", post=mock_post) + self.assertIn("access_token", result) + + # Verify the account was created + accounts = app.get_accounts() + self.assertTrue(len(accounts) > 0, "Account should be created from user_fic response") + account = accounts[0] + self.assertEqual("user_oid.tenant_id", account["home_account_id"]) + + def test_token_not_stored_as_atext(self): + """user_fic tokens should use standard AccessToken type, not atext.""" + app = self._make_app() + + def mock_post(url, headers=None, data=None, *args, **kwargs): + return MinimalResponse(status_code=200, text=_build_user_fic_response()) + + app.acquire_token_by_user_federated_identity_credential( + ["https://graph.microsoft.com/.default"], + assertion="t2", username="user@contoso.com", post=mock_post) + + # Check the raw cache for credential type + at_entries = list(app.token_cache.search( + msal.TokenCache.CredentialType.ACCESS_TOKEN, query={})) + self.assertTrue(len(at_entries) > 0, "AT should be cached") + self.assertNotIn("ext_cache_key", at_entries[0], + "user_fic tokens should NOT have ext_cache_key") + + def test_acquire_token_silent_returns_cached_fic_token(self): + app = self._make_app() + + def mock_post(url, headers=None, data=None, *args, **kwargs): + return MinimalResponse(status_code=200, text=_build_user_fic_response( + uid="user_oid", utid="tenant_id", access_token="cached_fic_at")) + + app.acquire_token_by_user_federated_identity_credential( + ["https://graph.microsoft.com/.default"], + assertion="t2", username="user@contoso.com", post=mock_post) + + accounts = app.get_accounts() + self.assertTrue(len(accounts) > 0) + + # Silent call should return cached token without hitting network + silent_result = app.acquire_token_silent( + ["https://graph.microsoft.com/.default"], account=accounts[0]) + self.assertIn("access_token", silent_result) + self.assertEqual("cached_fic_at", silent_result["access_token"]) + + def test_oid_path_token_stored_and_retrievable_via_silent(self): + """user_fic with user_object_id should cache and retrieve like username.""" + app = self._make_app() + + def mock_post(url, headers=None, data=None, *args, **kwargs): + return MinimalResponse(status_code=200, text=_build_user_fic_response( + uid="user_oid", utid="tenant_id", access_token="oid_fic_at")) + + result = app.acquire_token_by_user_federated_identity_credential( + ["https://graph.microsoft.com/.default"], + assertion="t2", user_object_id="user_oid", post=mock_post) + self.assertIn("access_token", result) + + # Verify no ext_cache_key on cached token + at_entries = list(app.token_cache.search( + msal.TokenCache.CredentialType.ACCESS_TOKEN, query={})) + self.assertTrue(len(at_entries) > 0, "AT should be cached") + self.assertNotIn("ext_cache_key", at_entries[0], + "OID-path user_fic tokens should NOT have ext_cache_key") + + # Verify account and silent retrieval + accounts = app.get_accounts() + self.assertTrue(len(accounts) > 0) + silent_result = app.acquire_token_silent( + ["https://graph.microsoft.com/.default"], account=accounts[0]) + self.assertIn("access_token", silent_result) + self.assertEqual("oid_fic_at", silent_result["access_token"]) + + +@patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) +class TestUserFicInputValidation(unittest.TestCase): + """Tests that input validation rejects invalid parameters.""" + + def _make_app(self): + return ConfidentialClientApplication( + "agent_app_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant") + + def test_empty_assertion_raises(self): + app = self._make_app() + with self.assertRaises(ValueError): + app.acquire_token_by_user_federated_identity_credential( + ["scope"], assertion="", username="user@contoso.com") + + def test_none_assertion_raises(self): + app = self._make_app() + with self.assertRaises(ValueError): + app.acquire_token_by_user_federated_identity_credential( + ["scope"], assertion=None, username="user@contoso.com") + + def test_no_user_identifier_raises(self): + app = self._make_app() + with self.assertRaises(ValueError): + app.acquire_token_by_user_federated_identity_credential( + ["scope"], assertion="t2") + + def test_both_user_identifiers_raises(self): + app = self._make_app() + with self.assertRaises(ValueError): + app.acquire_token_by_user_federated_identity_credential( + ["scope"], assertion="t2", + username="user@contoso.com", + user_object_id="oid-123") + + def test_reserved_scopes_rejected(self): + app = self._make_app() + with self.assertRaises(ValueError): + app.acquire_token_by_user_federated_identity_credential( + ["openid"], assertion="t2", username="user@contoso.com") + + +@patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) +class TestAssertionCallbackContext(unittest.TestCase): + """Tests that assertion callbacks receive context when they accept arguments.""" + + def test_context_aware_callback_receives_fmi_path(self): + received_context = {} + + def assertion_with_context(context): + received_context.update(context) + return "assertion_value" + + app = ConfidentialClientApplication( + "client_id", + client_credential={"client_assertion": assertion_with_context}, + authority="https://login.microsoftonline.com/my_tenant") + + app.acquire_token_for_client( + ["scope"], fmi_path="agent_app_123", + post=lambda url, **kwargs: MinimalResponse( + status_code=200, text=json.dumps({ + "access_token": "an_at", "expires_in": 3600}))) + + self.assertEqual("client_id", received_context.get("client_id")) + self.assertIn("token_endpoint", received_context) + self.assertEqual("agent_app_123", received_context.get("fmi_path")) + + def test_context_aware_callback_omits_fmi_path_when_not_set(self): + received_context = {} + + def assertion_with_context(context): + received_context.update(context) + return "assertion_value" + + app = ConfidentialClientApplication( + "client_id", + client_credential={"client_assertion": assertion_with_context}, + authority="https://login.microsoftonline.com/my_tenant") + + app.acquire_token_for_client( + ["scope"], + post=lambda url, **kwargs: MinimalResponse( + status_code=200, text=json.dumps({ + "access_token": "an_at", "expires_in": 3600}))) + + self.assertEqual("client_id", received_context.get("client_id")) + self.assertNotIn("fmi_path", received_context) + + def test_legacy_zero_arg_callback_still_works(self): + call_count = [0] + + def legacy_callback(): + call_count[0] += 1 + return "legacy_assertion" + + app = ConfidentialClientApplication( + "client_id", + client_credential={"client_assertion": legacy_callback}, + authority="https://login.microsoftonline.com/my_tenant") + + result = app.acquire_token_for_client( + ["scope"], + post=lambda url, **kwargs: MinimalResponse( + status_code=200, text=json.dumps({ + "access_token": "an_at", "expires_in": 3600}))) + + self.assertIn("access_token", result) + self.assertEqual(1, call_count[0], "Legacy callback should be invoked once") + + def test_context_callback_type_error_not_swallowed(self): + """If a one-arg callback raises TypeError internally, it should propagate.""" + def buggy_callback(context): + raise TypeError("Bug inside callback") + + app = ConfidentialClientApplication( + "client_id", + client_credential={"client_assertion": buggy_callback}, + authority="https://login.microsoftonline.com/my_tenant") + + with self.assertRaises(TypeError, msg="Internal TypeError should propagate"): + app.acquire_token_for_client( + ["scope"], + post=lambda url, **kwargs: MinimalResponse( + status_code=200, text=json.dumps({ + "access_token": "an_at", "expires_in": 3600}))) + + def test_lambda_with_defaulted_param_treated_as_zero_arg(self): + """A lambda like ``lambda token=token: token`` should be treated as + zero-arg because all its positional params have defaults.""" + captured_value = "my_assertion_value" + assertion_callable = lambda token=captured_value: token # noqa: E731 + + app = ConfidentialClientApplication( + "client_id", + client_credential={"client_assertion": assertion_callable}, + authority="https://login.microsoftonline.com/my_tenant") + + captured_data = {} + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse( + status_code=200, text=json.dumps({ + "access_token": "an_at", "expires_in": 3600})) + + result = app.acquire_token_for_client(["scope"], post=mock_post) + self.assertIn("access_token", result) + # The assertion should be the string value, not a dict context object + self.assertEqual( + captured_value, captured_data.get("client_assertion"), + "Lambda with defaulted params should return its default value, " + "not receive a context dict") \ No newline at end of file