# Mirror of https://github.com/d3vyce/fastapi-toolsets.git
# (synced 2026-04-15 22:26:25 +02:00; 141 lines, 4.4 KiB, Python)
"""OAuth 2.0 / OIDC helper utilities."""
|
|
|
|
import base64
|
|
from typing import Any
|
|
from urllib.parse import urlencode
|
|
|
|
import httpx
|
|
from fastapi.responses import RedirectResponse
|
|
|
|
# Parsed OIDC discovery documents, keyed by the discovery URL they were
# fetched from. Filled lazily by ``oauth_resolve_provider_urls``; the code
# in this section never evicts or refreshes an entry once it is stored.
_discovery_cache: dict[str, dict] = {}
|
|
|
|
|
|
async def oauth_resolve_provider_urls(
    discovery_url: str,
) -> tuple[str, str, str | None]:
    """Look up a provider's OAuth endpoints from its OIDC discovery document.

    A successfully fetched document is stored in the module-level
    ``_discovery_cache`` under *discovery_url* and served from there on
    later calls, so the provider is not contacted again for the same URL.

    Args:
        discovery_url: Location of the provider's
            ``/.well-known/openid-configuration`` document.

    Returns:
        The tuple ``(authorization_url, token_url, userinfo_url)``, where
        *userinfo_url* is ``None`` if the document has no
        ``userinfo_endpoint`` entry.
    """
    already_cached = discovery_url in _discovery_cache
    if not already_cached:
        async with httpx.AsyncClient() as http:
            response = await http.get(discovery_url)
            # An error status raises here, before anything is cached.
            response.raise_for_status()
            _discovery_cache[discovery_url] = response.json()

    document = _discovery_cache[discovery_url]
    # The first two endpoints are looked up with ``[]`` and so must be
    # present; the userinfo endpoint is optional.
    authorization_url = document["authorization_endpoint"]
    token_url = document["token_endpoint"]
    userinfo_url = document.get("userinfo_endpoint")
    return authorization_url, token_url, userinfo_url
|
|
|
|
|
|
async def oauth_fetch_userinfo(
    *,
    token_url: str,
    userinfo_url: str,
    code: str,
    client_id: str,
    client_secret: str,
    redirect_uri: str,
) -> dict[str, Any]:
    """Trade an authorization code for an access token and load the user profile.

    Both requests are issued through one shared ``httpx.AsyncClient``:

    1. A POST to *token_url* carrying the ``authorization_code`` grant.
    2. A GET to *userinfo_url* with the returned access token sent as a
       ``Bearer`` credential.

    ``raise_for_status()`` is called on each response, so an error status
    from either endpoint propagates as an exception.

    Args:
        token_url: Provider's token endpoint.
        userinfo_url: Provider's userinfo endpoint.
        code: Authorization code delivered to the callback by the provider.
        client_id: OAuth application client ID.
        client_secret: OAuth application client secret.
        redirect_uri: Redirect URI that was sent in the authorization request.

    Returns:
        The decoded JSON body of the userinfo response.
    """
    token_request = {
        "grant_type": "authorization_code",
        "code": code,
        "client_id": client_id,
        "client_secret": client_secret,
        "redirect_uri": redirect_uri,
    }

    async with httpx.AsyncClient() as http:
        token_response = await http.post(
            token_url,
            data=token_request,
            # Explicitly request JSON so the body can be parsed with .json().
            headers={"Accept": "application/json"},
        )
        token_response.raise_for_status()
        bearer = token_response.json()["access_token"]

        profile_response = await http.get(
            userinfo_url,
            headers={"Authorization": f"Bearer {bearer}"},
        )
        profile_response.raise_for_status()
        return profile_response.json()
|
|
|
|
|
|
def oauth_build_authorization_redirect(
    authorization_url: str,
    *,
    client_id: str,
    scopes: str,
    redirect_uri: str,
    destination: str,
) -> RedirectResponse:
    """Return an OAuth 2.0 authorization ``RedirectResponse``.

    The OAuth parameters are appended to *authorization_url*. If that URL
    already carries a query component, the existing query is kept and the new
    parameters are joined with ``&`` instead of a second ``?``.

    Args:
        authorization_url: Provider's authorization endpoint. May already
            contain a query string.
        client_id: OAuth application client ID.
        scopes: Space-separated list of requested scopes.
        redirect_uri: URI the provider should redirect back to after authorization.
        destination: URL the user should be sent to after the full OAuth flow
            completes (encoded as ``state`` via ``oauth_encode_state``).

    Returns:
        A :class:`~fastapi.responses.RedirectResponse` to the provider's
        authorization page.
    """
    params = urlencode(
        {
            "client_id": client_id,
            "response_type": "code",
            "scope": scopes,
            "redirect_uri": redirect_uri,
            "state": oauth_encode_state(destination),
        }
    )

    # Pick the joiner that keeps the final URL well-formed: an endpoint with
    # an existing query must be extended, not given a second "?".
    if "?" not in authorization_url:
        separator = "?"
    elif authorization_url.endswith(("?", "&")):
        # Already ends on a delimiter; adding another would leave an empty pair.
        separator = ""
    else:
        separator = "&"

    return RedirectResponse(f"{authorization_url}{separator}{params}")
|
|
|
|
|
|
def oauth_encode_state(url: str) -> str:
    """Encode *url* as URL-safe base64 text for use as an OAuth ``state`` value.

    Args:
        url: The URL to carry through the OAuth round trip.

    Returns:
        The URL-safe base64 encoding of the UTF-8 bytes of *url*, including
        any trailing ``=`` padding.
    """
    raw = url.encode("utf-8")
    encoded = base64.urlsafe_b64encode(raw)
    # Base64 output is pure ASCII, so this decode cannot fail.
    return encoded.decode("ascii")
|
|
|
|
|
|
def oauth_decode_state(state: str | None, *, fallback: str) -> str:
|
|
"""Decode a base64url OAuth ``state`` parameter.
|
|
|
|
Handles missing padding (some providers strip ``=``).
|
|
Returns *fallback* if *state* is absent, the literal string ``"null"``,
|
|
or cannot be decoded.
|
|
"""
|
|
if not state or state == "null":
|
|
return fallback
|
|
try:
|
|
padded = state + "=" * (4 - len(state) % 4)
|
|
return base64.urlsafe_b64decode(padded).decode()
|
|
except Exception:
|
|
return fallback
|