feat(security): add oauth helpers

This commit is contained in:
2026-03-07 10:29:20 -05:00
parent 0bc025b844
commit 83c1f98d25
3 changed files with 344 additions and 9 deletions

View File

@@ -1,7 +1,13 @@
"""Authentication helpers for FastAPI using Security()."""
from .abc import AuthSource
from .oauth import decode_oauth_state, encode_oauth_state
from .oauth import (
oauth_build_authorization_redirect,
oauth_decode_state,
oauth_encode_state,
oauth_fetch_userinfo,
oauth_resolve_provider_urls,
)
from .sources import APIKeyHeaderAuth, BearerTokenAuth, CookieAuth, MultiAuth
__all__ = [
@@ -10,6 +16,9 @@ __all__ = [
"BearerTokenAuth",
"CookieAuth",
"MultiAuth",
"decode_oauth_state",
"encode_oauth_state",
"oauth_build_authorization_redirect",
"oauth_decode_state",
"oauth_encode_state",
"oauth_fetch_userinfo",
"oauth_resolve_provider_urls",
]

View File

@@ -1,14 +1,130 @@
"""OAuth 2.0 / OIDC helper utilities."""
import base64
from typing import Any
from urllib.parse import urlencode
import httpx
from fastapi.responses import RedirectResponse
_discovery_cache: dict[str, dict] = {}
def encode_oauth_state(url: str) -> str:
async def oauth_resolve_provider_urls(
discovery_url: str,
) -> tuple[str, str, str | None]:
"""Fetch the OIDC discovery document and return endpoint URLs.
Args:
discovery_url: URL of the provider's ``/.well-known/openid-configuration``.
Returns:
A ``(authorization_url, token_url, userinfo_url)`` tuple.
*userinfo_url* is ``None`` when the provider does not advertise one.
"""
if discovery_url not in _discovery_cache:
async with httpx.AsyncClient() as client:
resp = await client.get(discovery_url)
resp.raise_for_status()
_discovery_cache[discovery_url] = resp.json()
cfg = _discovery_cache[discovery_url]
return (
cfg["authorization_endpoint"],
cfg["token_endpoint"],
cfg.get("userinfo_endpoint"),
)
async def oauth_fetch_userinfo(
*,
token_url: str,
userinfo_url: str,
code: str,
client_id: str,
client_secret: str,
redirect_uri: str,
) -> dict[str, Any]:
"""Exchange an authorization code for tokens and return the userinfo payload.
Performs the two-step OAuth 2.0 / OIDC token exchange:
1. POSTs the authorization *code* to *token_url* to obtain an access token.
2. GETs *userinfo_url* using that access token as a Bearer credential.
Args:
token_url: Provider's token endpoint.
userinfo_url: Provider's userinfo endpoint.
code: Authorization code received from the provider's callback.
client_id: OAuth application client ID.
client_secret: OAuth application client secret.
redirect_uri: Redirect URI that was used in the authorization request.
Returns:
The JSON payload returned by the userinfo endpoint as a plain ``dict``.
"""
async with httpx.AsyncClient() as client:
token_resp = await client.post(
token_url,
data={
"grant_type": "authorization_code",
"code": code,
"client_id": client_id,
"client_secret": client_secret,
"redirect_uri": redirect_uri,
},
headers={"Accept": "application/json"},
)
token_resp.raise_for_status()
access_token = token_resp.json()["access_token"]
userinfo_resp = await client.get(
userinfo_url,
headers={"Authorization": f"Bearer {access_token}"},
)
userinfo_resp.raise_for_status()
return userinfo_resp.json()
def oauth_build_authorization_redirect(
authorization_url: str,
*,
client_id: str,
scopes: str,
redirect_uri: str,
destination: str,
) -> RedirectResponse:
"""Return an OAuth 2.0 authorization ``RedirectResponse``.
Args:
authorization_url: Provider's authorization endpoint.
client_id: OAuth application client ID.
scopes: Space-separated list of requested scopes.
redirect_uri: URI the provider should redirect back to after authorization.
destination: URL the user should be sent to after the full OAuth flow
completes (encoded as ``state``).
Returns:
A :class:`~fastapi.responses.RedirectResponse` to the provider's
authorization page.
"""
params = urlencode(
{
"client_id": client_id,
"response_type": "code",
"scope": scopes,
"redirect_uri": redirect_uri,
"state": oauth_encode_state(destination),
}
)
return RedirectResponse(f"{authorization_url}?{params}")
def oauth_encode_state(url: str) -> str:
"""Base64url-encode a URL to embed as an OAuth ``state`` parameter."""
return base64.urlsafe_b64encode(url.encode()).decode()
def decode_oauth_state(state: str | None, *, fallback: str) -> str:
def oauth_decode_state(state: str | None, *, fallback: str) -> str:
"""Decode a base64url OAuth ``state`` parameter.
Handles missing padding (some providers strip ``=``).

View File

@@ -1,5 +1,8 @@
"""Tests for fastapi_toolsets.security."""
from unittest.mock import AsyncMock, MagicMock, patch
from urllib.parse import parse_qs, urlparse
import pytest
from fastapi import FastAPI, Security
from fastapi.testclient import TestClient
@@ -11,6 +14,11 @@ from fastapi_toolsets.security import (
BearerTokenAuth,
CookieAuth,
MultiAuth,
oauth_build_authorization_redirect,
oauth_decode_state,
oauth_encode_state,
oauth_fetch_userinfo,
oauth_resolve_provider_urls,
)
@@ -129,8 +137,6 @@ class TestBearerTokenAuth:
response = client.get("/me", headers={"Authorization": "Bearer org_abc123"})
assert response.status_code == 401
# --- extract() ---
@pytest.mark.anyio
async def test_extract_no_header(self):
from starlette.requests import Request
@@ -196,8 +202,6 @@ class TestBearerTokenAuth:
request = Request(scope)
assert await bearer.extract(request) is None
# --- generate_token() ---
def test_generate_token_no_prefix(self):
bearer = BearerTokenAuth(simple_validator)
token = bearer.generate_token()
@@ -962,3 +966,209 @@ class TestAuthSource:
response = client.get("/me", headers={"X-Token": "s3cr3t"})
assert response.status_code == 200
assert response.json() == {"token": "s3cr3t"}
def _make_async_client_mock(get_return=None, post_return=None):
"""Return a patched httpx.AsyncClient context-manager mock."""
mock_client = AsyncMock()
if get_return is not None:
mock_client.get.return_value = get_return
if post_return is not None:
mock_client.post.return_value = post_return
cm = MagicMock()
cm.__aenter__ = AsyncMock(return_value=mock_client)
cm.__aexit__ = AsyncMock(return_value=None)
return cm, mock_client
class TestEncodeDecodeOAuthState:
def test_encode_returns_base64url_string(self):
result = oauth_encode_state("https://example.com/dashboard")
assert isinstance(result, str)
assert "+" not in result
assert "/" not in result
def test_round_trip(self):
url = "https://example.com/after-login?next=/home"
assert oauth_decode_state(oauth_encode_state(url), fallback="/") == url
def test_decode_none_returns_fallback(self):
assert oauth_decode_state(None, fallback="/home") == "/home"
def test_decode_null_string_returns_fallback(self):
assert oauth_decode_state("null", fallback="/home") == "/home"
def test_decode_invalid_base64_returns_fallback(self):
assert oauth_decode_state("!!!notbase64!!!", fallback="/home") == "/home"
def test_decode_handles_missing_padding(self):
url = "https://example.com/x"
encoded = oauth_encode_state(url).rstrip("=")
assert oauth_decode_state(encoded, fallback="/") == url
class TestBuildAuthorizationRedirect:
def test_returns_redirect_response(self):
from fastapi.responses import RedirectResponse
response = oauth_build_authorization_redirect(
"https://auth.example.com/authorize",
client_id="my-client",
scopes="openid email",
redirect_uri="https://app.example.com/callback",
destination="https://app.example.com/dashboard",
)
assert isinstance(response, RedirectResponse)
def test_redirect_location_contains_all_params(self):
response = oauth_build_authorization_redirect(
"https://auth.example.com/authorize",
client_id="my-client",
scopes="openid email",
redirect_uri="https://app.example.com/callback",
destination="https://app.example.com/dashboard",
)
location = response.headers["location"]
parsed = urlparse(location)
assert (
parsed.scheme + "://" + parsed.netloc + parsed.path
== "https://auth.example.com/authorize"
)
params = parse_qs(parsed.query)
assert params["client_id"] == ["my-client"]
assert params["response_type"] == ["code"]
assert params["scope"] == ["openid email"]
assert params["redirect_uri"] == ["https://app.example.com/callback"]
assert (
oauth_decode_state(params["state"][0], fallback="")
== "https://app.example.com/dashboard"
)
class TestResolveProviderUrls:
def _discovery(self, *, userinfo=True):
doc = {
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
}
if userinfo:
doc["userinfo_endpoint"] = "https://auth.example.com/userinfo"
return doc
@pytest.mark.anyio
async def test_returns_all_endpoints(self):
mock_resp = MagicMock()
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = self._discovery()
cm, mock_client = _make_async_client_mock(get_return=mock_resp)
with patch("fastapi_toolsets.security.oauth._discovery_cache", {}):
with patch("httpx.AsyncClient", return_value=cm):
auth_url, token_url, userinfo_url = await oauth_resolve_provider_urls(
"https://auth.example.com/.well-known/openid-configuration"
)
assert auth_url == "https://auth.example.com/authorize"
assert token_url == "https://auth.example.com/token"
assert userinfo_url == "https://auth.example.com/userinfo"
@pytest.mark.anyio
async def test_userinfo_url_none_when_absent(self):
mock_resp = MagicMock()
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = self._discovery(userinfo=False)
cm, mock_client = _make_async_client_mock(get_return=mock_resp)
with patch("fastapi_toolsets.security.oauth._discovery_cache", {}):
with patch("httpx.AsyncClient", return_value=cm):
_, _, userinfo_url = await oauth_resolve_provider_urls(
"https://auth.example.com/.well-known/openid-configuration"
)
assert userinfo_url is None
@pytest.mark.anyio
async def test_caches_discovery_document(self):
mock_resp = MagicMock()
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = self._discovery()
cm, mock_client = _make_async_client_mock(get_return=mock_resp)
url = "https://auth.example.com/.well-known/openid-configuration"
with patch("fastapi_toolsets.security.oauth._discovery_cache", {}):
with patch("httpx.AsyncClient", return_value=cm):
await oauth_resolve_provider_urls(url)
await oauth_resolve_provider_urls(url)
assert mock_client.get.call_count == 1
class TestFetchUserinfo:
@pytest.mark.anyio
async def test_returns_userinfo_payload(self):
token_resp = MagicMock()
token_resp.raise_for_status = MagicMock()
token_resp.json.return_value = {"access_token": "tok123"}
userinfo_resp = MagicMock()
userinfo_resp.raise_for_status = MagicMock()
userinfo_resp.json.return_value = {
"sub": "user-1",
"email": "alice@example.com",
}
cm, mock_client = _make_async_client_mock(
post_return=token_resp, get_return=userinfo_resp
)
with patch("httpx.AsyncClient", return_value=cm):
result = await oauth_fetch_userinfo(
token_url="https://auth.example.com/token",
userinfo_url="https://auth.example.com/userinfo",
code="authcode123",
client_id="client-id",
client_secret="client-secret",
redirect_uri="https://app.example.com/callback",
)
assert result == {"sub": "user-1", "email": "alice@example.com"}
@pytest.mark.anyio
async def test_posts_correct_token_request_and_uses_bearer(self):
token_resp = MagicMock()
token_resp.raise_for_status = MagicMock()
token_resp.json.return_value = {"access_token": "tok123"}
userinfo_resp = MagicMock()
userinfo_resp.raise_for_status = MagicMock()
userinfo_resp.json.return_value = {}
cm, mock_client = _make_async_client_mock(
post_return=token_resp, get_return=userinfo_resp
)
with patch("httpx.AsyncClient", return_value=cm):
await oauth_fetch_userinfo(
token_url="https://auth.example.com/token",
userinfo_url="https://auth.example.com/userinfo",
code="authcode123",
client_id="my-client",
client_secret="my-secret",
redirect_uri="https://app.example.com/callback",
)
mock_client.post.assert_called_once_with(
"https://auth.example.com/token",
data={
"grant_type": "authorization_code",
"code": "authcode123",
"client_id": "my-client",
"client_secret": "my-secret",
"redirect_uri": "https://app.example.com/callback",
},
headers={"Accept": "application/json"},
)
mock_client.get.assert_called_once_with(
"https://auth.example.com/userinfo",
headers={"Authorization": "Bearer tok123"},
)