mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 14:46:24 +02:00
feat(security): add oauth helpers
This commit is contained in:
@@ -1,7 +1,13 @@
|
|||||||
"""Authentication helpers for FastAPI using Security()."""
|
"""Authentication helpers for FastAPI using Security()."""
|
||||||
|
|
||||||
from .abc import AuthSource
|
from .abc import AuthSource
|
||||||
from .oauth import decode_oauth_state, encode_oauth_state
|
from .oauth import (
|
||||||
|
oauth_build_authorization_redirect,
|
||||||
|
oauth_decode_state,
|
||||||
|
oauth_encode_state,
|
||||||
|
oauth_fetch_userinfo,
|
||||||
|
oauth_resolve_provider_urls,
|
||||||
|
)
|
||||||
from .sources import APIKeyHeaderAuth, BearerTokenAuth, CookieAuth, MultiAuth
|
from .sources import APIKeyHeaderAuth, BearerTokenAuth, CookieAuth, MultiAuth
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -10,6 +16,9 @@ __all__ = [
|
|||||||
"BearerTokenAuth",
|
"BearerTokenAuth",
|
||||||
"CookieAuth",
|
"CookieAuth",
|
||||||
"MultiAuth",
|
"MultiAuth",
|
||||||
"decode_oauth_state",
|
"oauth_build_authorization_redirect",
|
||||||
"encode_oauth_state",
|
"oauth_decode_state",
|
||||||
|
"oauth_encode_state",
|
||||||
|
"oauth_fetch_userinfo",
|
||||||
|
"oauth_resolve_provider_urls",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,14 +1,130 @@
|
|||||||
"""OAuth 2.0 / OIDC helper utilities."""
|
"""OAuth 2.0 / OIDC helper utilities."""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
|
from typing import Any
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
|
_discovery_cache: dict[str, dict] = {}
|
||||||
|
|
||||||
|
|
||||||
def encode_oauth_state(url: str) -> str:
|
async def oauth_resolve_provider_urls(
|
||||||
|
discovery_url: str,
|
||||||
|
) -> tuple[str, str, str | None]:
|
||||||
|
"""Fetch the OIDC discovery document and return endpoint URLs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
discovery_url: URL of the provider's ``/.well-known/openid-configuration``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A ``(authorization_url, token_url, userinfo_url)`` tuple.
|
||||||
|
*userinfo_url* is ``None`` when the provider does not advertise one.
|
||||||
|
"""
|
||||||
|
if discovery_url not in _discovery_cache:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
resp = await client.get(discovery_url)
|
||||||
|
resp.raise_for_status()
|
||||||
|
_discovery_cache[discovery_url] = resp.json()
|
||||||
|
cfg = _discovery_cache[discovery_url]
|
||||||
|
return (
|
||||||
|
cfg["authorization_endpoint"],
|
||||||
|
cfg["token_endpoint"],
|
||||||
|
cfg.get("userinfo_endpoint"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def oauth_fetch_userinfo(
|
||||||
|
*,
|
||||||
|
token_url: str,
|
||||||
|
userinfo_url: str,
|
||||||
|
code: str,
|
||||||
|
client_id: str,
|
||||||
|
client_secret: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Exchange an authorization code for tokens and return the userinfo payload.
|
||||||
|
|
||||||
|
Performs the two-step OAuth 2.0 / OIDC token exchange:
|
||||||
|
|
||||||
|
1. POSTs the authorization *code* to *token_url* to obtain an access token.
|
||||||
|
2. GETs *userinfo_url* using that access token as a Bearer credential.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_url: Provider's token endpoint.
|
||||||
|
userinfo_url: Provider's userinfo endpoint.
|
||||||
|
code: Authorization code received from the provider's callback.
|
||||||
|
client_id: OAuth application client ID.
|
||||||
|
client_secret: OAuth application client secret.
|
||||||
|
redirect_uri: Redirect URI that was used in the authorization request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The JSON payload returned by the userinfo endpoint as a plain ``dict``.
|
||||||
|
"""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
token_resp = await client.post(
|
||||||
|
token_url,
|
||||||
|
data={
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"code": code,
|
||||||
|
"client_id": client_id,
|
||||||
|
"client_secret": client_secret,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
},
|
||||||
|
headers={"Accept": "application/json"},
|
||||||
|
)
|
||||||
|
token_resp.raise_for_status()
|
||||||
|
access_token = token_resp.json()["access_token"]
|
||||||
|
|
||||||
|
userinfo_resp = await client.get(
|
||||||
|
userinfo_url,
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
userinfo_resp.raise_for_status()
|
||||||
|
return userinfo_resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
def oauth_build_authorization_redirect(
|
||||||
|
authorization_url: str,
|
||||||
|
*,
|
||||||
|
client_id: str,
|
||||||
|
scopes: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
destination: str,
|
||||||
|
) -> RedirectResponse:
|
||||||
|
"""Return an OAuth 2.0 authorization ``RedirectResponse``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
authorization_url: Provider's authorization endpoint.
|
||||||
|
client_id: OAuth application client ID.
|
||||||
|
scopes: Space-separated list of requested scopes.
|
||||||
|
redirect_uri: URI the provider should redirect back to after authorization.
|
||||||
|
destination: URL the user should be sent to after the full OAuth flow
|
||||||
|
completes (encoded as ``state``).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A :class:`~fastapi.responses.RedirectResponse` to the provider's
|
||||||
|
authorization page.
|
||||||
|
"""
|
||||||
|
params = urlencode(
|
||||||
|
{
|
||||||
|
"client_id": client_id,
|
||||||
|
"response_type": "code",
|
||||||
|
"scope": scopes,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"state": oauth_encode_state(destination),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return RedirectResponse(f"{authorization_url}?{params}")
|
||||||
|
|
||||||
|
|
||||||
|
def oauth_encode_state(url: str) -> str:
|
||||||
"""Base64url-encode a URL to embed as an OAuth ``state`` parameter."""
|
"""Base64url-encode a URL to embed as an OAuth ``state`` parameter."""
|
||||||
return base64.urlsafe_b64encode(url.encode()).decode()
|
return base64.urlsafe_b64encode(url.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
def decode_oauth_state(state: str | None, *, fallback: str) -> str:
|
def oauth_decode_state(state: str | None, *, fallback: str) -> str:
|
||||||
"""Decode a base64url OAuth ``state`` parameter.
|
"""Decode a base64url OAuth ``state`` parameter.
|
||||||
|
|
||||||
Handles missing padding (some providers strip ``=``).
|
Handles missing padding (some providers strip ``=``).
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
"""Tests for fastapi_toolsets.security."""
|
"""Tests for fastapi_toolsets.security."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import FastAPI, Security
|
from fastapi import FastAPI, Security
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
@@ -11,6 +14,11 @@ from fastapi_toolsets.security import (
|
|||||||
BearerTokenAuth,
|
BearerTokenAuth,
|
||||||
CookieAuth,
|
CookieAuth,
|
||||||
MultiAuth,
|
MultiAuth,
|
||||||
|
oauth_build_authorization_redirect,
|
||||||
|
oauth_decode_state,
|
||||||
|
oauth_encode_state,
|
||||||
|
oauth_fetch_userinfo,
|
||||||
|
oauth_resolve_provider_urls,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -129,8 +137,6 @@ class TestBearerTokenAuth:
|
|||||||
response = client.get("/me", headers={"Authorization": "Bearer org_abc123"})
|
response = client.get("/me", headers={"Authorization": "Bearer org_abc123"})
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
|
|
||||||
# --- extract() ---
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_extract_no_header(self):
|
async def test_extract_no_header(self):
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
@@ -196,8 +202,6 @@ class TestBearerTokenAuth:
|
|||||||
request = Request(scope)
|
request = Request(scope)
|
||||||
assert await bearer.extract(request) is None
|
assert await bearer.extract(request) is None
|
||||||
|
|
||||||
# --- generate_token() ---
|
|
||||||
|
|
||||||
def test_generate_token_no_prefix(self):
|
def test_generate_token_no_prefix(self):
|
||||||
bearer = BearerTokenAuth(simple_validator)
|
bearer = BearerTokenAuth(simple_validator)
|
||||||
token = bearer.generate_token()
|
token = bearer.generate_token()
|
||||||
@@ -962,3 +966,209 @@ class TestAuthSource:
|
|||||||
response = client.get("/me", headers={"X-Token": "s3cr3t"})
|
response = client.get("/me", headers={"X-Token": "s3cr3t"})
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {"token": "s3cr3t"}
|
assert response.json() == {"token": "s3cr3t"}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_async_client_mock(get_return=None, post_return=None):
|
||||||
|
"""Return a patched httpx.AsyncClient context-manager mock."""
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
if get_return is not None:
|
||||||
|
mock_client.get.return_value = get_return
|
||||||
|
if post_return is not None:
|
||||||
|
mock_client.post.return_value = post_return
|
||||||
|
cm = MagicMock()
|
||||||
|
cm.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
cm.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
return cm, mock_client
|
||||||
|
|
||||||
|
|
||||||
|
class TestEncodeDecodeOAuthState:
|
||||||
|
def test_encode_returns_base64url_string(self):
|
||||||
|
result = oauth_encode_state("https://example.com/dashboard")
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert "+" not in result
|
||||||
|
assert "/" not in result
|
||||||
|
|
||||||
|
def test_round_trip(self):
|
||||||
|
url = "https://example.com/after-login?next=/home"
|
||||||
|
assert oauth_decode_state(oauth_encode_state(url), fallback="/") == url
|
||||||
|
|
||||||
|
def test_decode_none_returns_fallback(self):
|
||||||
|
assert oauth_decode_state(None, fallback="/home") == "/home"
|
||||||
|
|
||||||
|
def test_decode_null_string_returns_fallback(self):
|
||||||
|
assert oauth_decode_state("null", fallback="/home") == "/home"
|
||||||
|
|
||||||
|
def test_decode_invalid_base64_returns_fallback(self):
|
||||||
|
assert oauth_decode_state("!!!notbase64!!!", fallback="/home") == "/home"
|
||||||
|
|
||||||
|
def test_decode_handles_missing_padding(self):
|
||||||
|
url = "https://example.com/x"
|
||||||
|
encoded = oauth_encode_state(url).rstrip("=")
|
||||||
|
assert oauth_decode_state(encoded, fallback="/") == url
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildAuthorizationRedirect:
|
||||||
|
def test_returns_redirect_response(self):
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
|
response = oauth_build_authorization_redirect(
|
||||||
|
"https://auth.example.com/authorize",
|
||||||
|
client_id="my-client",
|
||||||
|
scopes="openid email",
|
||||||
|
redirect_uri="https://app.example.com/callback",
|
||||||
|
destination="https://app.example.com/dashboard",
|
||||||
|
)
|
||||||
|
assert isinstance(response, RedirectResponse)
|
||||||
|
|
||||||
|
def test_redirect_location_contains_all_params(self):
|
||||||
|
response = oauth_build_authorization_redirect(
|
||||||
|
"https://auth.example.com/authorize",
|
||||||
|
client_id="my-client",
|
||||||
|
scopes="openid email",
|
||||||
|
redirect_uri="https://app.example.com/callback",
|
||||||
|
destination="https://app.example.com/dashboard",
|
||||||
|
)
|
||||||
|
location = response.headers["location"]
|
||||||
|
parsed = urlparse(location)
|
||||||
|
assert (
|
||||||
|
parsed.scheme + "://" + parsed.netloc + parsed.path
|
||||||
|
== "https://auth.example.com/authorize"
|
||||||
|
)
|
||||||
|
params = parse_qs(parsed.query)
|
||||||
|
assert params["client_id"] == ["my-client"]
|
||||||
|
assert params["response_type"] == ["code"]
|
||||||
|
assert params["scope"] == ["openid email"]
|
||||||
|
assert params["redirect_uri"] == ["https://app.example.com/callback"]
|
||||||
|
assert (
|
||||||
|
oauth_decode_state(params["state"][0], fallback="")
|
||||||
|
== "https://app.example.com/dashboard"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestResolveProviderUrls:
|
||||||
|
def _discovery(self, *, userinfo=True):
|
||||||
|
doc = {
|
||||||
|
"authorization_endpoint": "https://auth.example.com/authorize",
|
||||||
|
"token_endpoint": "https://auth.example.com/token",
|
||||||
|
}
|
||||||
|
if userinfo:
|
||||||
|
doc["userinfo_endpoint"] = "https://auth.example.com/userinfo"
|
||||||
|
return doc
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_returns_all_endpoints(self):
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.raise_for_status = MagicMock()
|
||||||
|
mock_resp.json.return_value = self._discovery()
|
||||||
|
cm, mock_client = _make_async_client_mock(get_return=mock_resp)
|
||||||
|
|
||||||
|
with patch("fastapi_toolsets.security.oauth._discovery_cache", {}):
|
||||||
|
with patch("httpx.AsyncClient", return_value=cm):
|
||||||
|
auth_url, token_url, userinfo_url = await oauth_resolve_provider_urls(
|
||||||
|
"https://auth.example.com/.well-known/openid-configuration"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert auth_url == "https://auth.example.com/authorize"
|
||||||
|
assert token_url == "https://auth.example.com/token"
|
||||||
|
assert userinfo_url == "https://auth.example.com/userinfo"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_userinfo_url_none_when_absent(self):
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.raise_for_status = MagicMock()
|
||||||
|
mock_resp.json.return_value = self._discovery(userinfo=False)
|
||||||
|
cm, mock_client = _make_async_client_mock(get_return=mock_resp)
|
||||||
|
|
||||||
|
with patch("fastapi_toolsets.security.oauth._discovery_cache", {}):
|
||||||
|
with patch("httpx.AsyncClient", return_value=cm):
|
||||||
|
_, _, userinfo_url = await oauth_resolve_provider_urls(
|
||||||
|
"https://auth.example.com/.well-known/openid-configuration"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert userinfo_url is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_caches_discovery_document(self):
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.raise_for_status = MagicMock()
|
||||||
|
mock_resp.json.return_value = self._discovery()
|
||||||
|
cm, mock_client = _make_async_client_mock(get_return=mock_resp)
|
||||||
|
|
||||||
|
url = "https://auth.example.com/.well-known/openid-configuration"
|
||||||
|
with patch("fastapi_toolsets.security.oauth._discovery_cache", {}):
|
||||||
|
with patch("httpx.AsyncClient", return_value=cm):
|
||||||
|
await oauth_resolve_provider_urls(url)
|
||||||
|
await oauth_resolve_provider_urls(url)
|
||||||
|
|
||||||
|
assert mock_client.get.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestFetchUserinfo:
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_returns_userinfo_payload(self):
|
||||||
|
token_resp = MagicMock()
|
||||||
|
token_resp.raise_for_status = MagicMock()
|
||||||
|
token_resp.json.return_value = {"access_token": "tok123"}
|
||||||
|
|
||||||
|
userinfo_resp = MagicMock()
|
||||||
|
userinfo_resp.raise_for_status = MagicMock()
|
||||||
|
userinfo_resp.json.return_value = {
|
||||||
|
"sub": "user-1",
|
||||||
|
"email": "alice@example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
cm, mock_client = _make_async_client_mock(
|
||||||
|
post_return=token_resp, get_return=userinfo_resp
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=cm):
|
||||||
|
result = await oauth_fetch_userinfo(
|
||||||
|
token_url="https://auth.example.com/token",
|
||||||
|
userinfo_url="https://auth.example.com/userinfo",
|
||||||
|
code="authcode123",
|
||||||
|
client_id="client-id",
|
||||||
|
client_secret="client-secret",
|
||||||
|
redirect_uri="https://app.example.com/callback",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {"sub": "user-1", "email": "alice@example.com"}
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_posts_correct_token_request_and_uses_bearer(self):
|
||||||
|
token_resp = MagicMock()
|
||||||
|
token_resp.raise_for_status = MagicMock()
|
||||||
|
token_resp.json.return_value = {"access_token": "tok123"}
|
||||||
|
|
||||||
|
userinfo_resp = MagicMock()
|
||||||
|
userinfo_resp.raise_for_status = MagicMock()
|
||||||
|
userinfo_resp.json.return_value = {}
|
||||||
|
|
||||||
|
cm, mock_client = _make_async_client_mock(
|
||||||
|
post_return=token_resp, get_return=userinfo_resp
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=cm):
|
||||||
|
await oauth_fetch_userinfo(
|
||||||
|
token_url="https://auth.example.com/token",
|
||||||
|
userinfo_url="https://auth.example.com/userinfo",
|
||||||
|
code="authcode123",
|
||||||
|
client_id="my-client",
|
||||||
|
client_secret="my-secret",
|
||||||
|
redirect_uri="https://app.example.com/callback",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client.post.assert_called_once_with(
|
||||||
|
"https://auth.example.com/token",
|
||||||
|
data={
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"code": "authcode123",
|
||||||
|
"client_id": "my-client",
|
||||||
|
"client_secret": "my-secret",
|
||||||
|
"redirect_uri": "https://app.example.com/callback",
|
||||||
|
},
|
||||||
|
headers={"Accept": "application/json"},
|
||||||
|
)
|
||||||
|
mock_client.get.assert_called_once_with(
|
||||||
|
"https://auth.example.com/userinfo",
|
||||||
|
headers={"Authorization": "Bearer tok123"},
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user