Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- `CrossAppAccessFlow.start()` now accepts an optional `resource` parameter (RFC 8707), forwarded to the token exchange alongside `audience` and `scope`.
- `OAuth2Error` now exposes an `additional_fields` mapping containing any non-standard keys returned in the error response body, so server-specific remediation hints are no longer discarded.
- `OAuth2Error.from_response()` classmethod builds an error from a parsed OAuth2 error response body, mapping standard RFC 6749 fields to their attributes and collecting the rest into `additional_fields`.

## 0.2.0

Expand Down
33 changes: 33 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,39 @@ flow = CrossAppAccessFlow(

</details>

## Error Handling

Authentication flows raise `OAuth2Error` when the authorization server
returns an error response, or when the SDK detects a protocol violation
locally (e.g., a `state` mismatch on the authorization-code callback).

```python
from okta_client.authfoundation import OAuth2Error

try:
token = await flow.start(...)
except OAuth2Error as err:
print(err.error) # RFC 6749 error code, e.g. "invalid_grant"
print(err.error_description) # Human-readable description (if provided)
print(err.error_uri) # Documentation link (if provided)
print(err.status_code) # HTTP status (server responses only)
print(err.request_id) # Request ID header (server responses only)
```

Servers sometimes return additional keys alongside the standard fields —
for example `required_acr` and `max_age` on a step-up challenge, or
Okta-specific `errorCauses` / `errorId` values. Any keys the SDK doesn't
already model are preserved verbatim on `OAuth2Error.additional_fields`:

```python
except OAuth2Error as err:
if err.error == "interaction_required":
required_acr = err.additional_fields.get("required_acr")
# ...re-prompt the user at the requested assurance level
```

Locally-raised errors (no server payload) leave `additional_fields` empty.

## Listeners

A common pattern within this SDK is the use of "Listeners" which enable developers to observe key events within the SDK's lifecycle. This permits you to implement some protocol within your application, and add your class instance as a listener to the client or flow you would like to observe.
Expand Down
8 changes: 3 additions & 5 deletions src/okta_client/authfoundation/oauth2/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable

from okta_client.authfoundation.oauth2.requests.oauth_authorization_server import OAuthAuthorizationServerRequest
from okta_client.authfoundation.utils import coerce_optional_sequence, coerce_optional_str
from okta_client.authfoundation.utils import coerce_optional_sequence

from ..coalesced_result import CoalescedResult
from ..networking import APIClient, APIClientListener, APIResponse, NetworkInterface
Expand Down Expand Up @@ -413,10 +413,8 @@ def _raise_for_oauth2_error(
except Exception:
error = None
if error is None and ("error" in result or response.status_code >= 400):
error = OAuth2Error(
error=str(result.get("error", "oauth2_error")),
error_description=coerce_optional_str(result.get("error_description")),
error_uri=coerce_optional_str(result.get("error_uri")),
error = OAuth2Error.from_response(
result,
status_code=response.status_code,
request_id=response.request_id,
)
Expand Down
37 changes: 36 additions & 1 deletion src/okta_client/authfoundation/oauth2/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@

from __future__ import annotations

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Mapping

_STANDARD_FIELDS = frozenset({"error", "error_description", "error_uri"})


def _coerce_optional_str(value: Any) -> str | None:
return None if value is None else str(value)


@dataclass
Expand All @@ -21,6 +28,7 @@ class OAuth2Error(Exception):
error_uri: str | None = None
status_code: int | None = None
request_id: str | None = None
additional_fields: Mapping[str, Any] = field(default_factory=dict)

def __str__(self) -> str:
"""Return a readable error string."""
Expand All @@ -30,3 +38,30 @@ def __str__(self) -> str:
if self.error_uri:
details.append(self.error_uri)
return ": ".join(details)

@classmethod
def from_response(
cls,
data: Mapping[str, Any],
*,
status_code: int | None = None,
request_id: str | None = None,
) -> "OAuth2Error":
"""Build an :class:`OAuth2Error` from a parsed OAuth2 error response body.

Standard RFC 6749 keys (``error``, ``error_description``, ``error_uri``)
are mapped to their dedicated attributes; any other keys are kept
verbatim on :attr:`additional_fields` so callers can inspect
server-specific remediation hints.

``error`` defaults to ``"oauth2_error"`` when the response body omits it
(e.g., a 5xx with no JSON ``error`` key).
"""
return cls(
error=str(data.get("error", "oauth2_error")),
error_description=_coerce_optional_str(data.get("error_description")),
error_uri=_coerce_optional_str(data.get("error_uri")),
status_code=status_code,
request_id=request_id,
additional_fields={k: v for k, v in data.items() if k not in _STANDARD_FIELDS},
)
11 changes: 2 additions & 9 deletions src/okta_client/authfoundation/oauth2/request_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
from collections.abc import Mapping
from typing import Any, Protocol, runtime_checkable

from okta_client.authfoundation.utils import coerce_optional_str

from ..networking import (
APIContentType,
APIParsingContext,
Expand Down Expand Up @@ -164,11 +162,6 @@ def accepts_type(self) -> APIContentType | None:

def parse_error(self, data: Mapping[str, Any]) -> Exception | None:
"""Parse standard OAuth2 error fields when present."""
error = data.get("error")
if not error:
if not data.get("error"):
return None
return OAuth2Error(
error=str(error),
error_description=coerce_optional_str(data.get("error_description")),
error_uri=coerce_optional_str(data.get("error_uri")),
)
return OAuth2Error.from_response(data)
74 changes: 74 additions & 0 deletions tests/test_oauth2_exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,77 @@ def test_oauth2_exchange_oauth_error() -> None:
assert error.error_description == "invalid credentials"
return
raise AssertionError("Expected OAuth2Error for invalid_grant")


def test_oauth2_error_preserves_server_additional_fields() -> None:
"""Non-standard fields in the token-endpoint error body are preserved on OAuth2Error."""
openid = OpenIdConfiguration.from_json(
{
"authorization_endpoint": "https://example.com/auth",
"token_endpoint": "https://example.com/token",
"jwks_uri": "https://example.com/keys",
}
)
token_body = json.dumps(
{
"error": "interaction_required",
"error_description": "step-up required",
"required_acr": "urn:okta:loa:2fa:any",
"max_age": 0,
}
).encode("utf-8")
discovery_body = json.dumps(
{
"issuer": "https://example.com",
"authorization_endpoint": "https://example.com/auth",
"token_endpoint": "https://example.com/token",
"jwks_uri": "https://example.com/keys",
}
).encode("utf-8")
jwks_body = json.dumps({"keys": []}).encode("utf-8")
network = DummyNetwork(
responses={
"https://example.com/.well-known/openid-configuration": RawResponse(
status_code=200, headers={}, body=discovery_body,
),
"https://example.com/keys?client_id=client": RawResponse(
status_code=200, headers={}, body=jwks_body,
),
"https://example.com/token": RawResponse(
status_code=400, headers={}, body=token_body,
),
}
)
client = OAuth2Client(
configuration=OAuth2ClientConfiguration(
issuer="https://example.com",
scope=["openid"],
client_authorization=ClientIdAuthorization(id="client"),
),
network=network,
)
request = TokenExchangeRequest(
_openid_configuration=openid,
_client_configuration=client.configuration,
username="user",
password="pass",
)

try:
asyncio.run(client.exchange(request))
except OAuth2Error as error:
assert error.error == "interaction_required"
assert error.additional_fields == {
"required_acr": "urn:okta:loa:2fa:any",
"max_age": 0,
}
# str() should remain unchanged (no extras appended).
assert str(error) == "interaction_required: step-up required"
return
raise AssertionError("Expected OAuth2Error for interaction_required")


def test_oauth2_error_default_additional_fields_is_empty() -> None:
"""Locally-raised OAuth2Errors have an empty additional_fields mapping."""
err = OAuth2Error(error="state_mismatch", error_description="bad state")
assert err.additional_fields == {}