From 233362dd35ec72c5278a33d933ded2f36dd20066 Mon Sep 17 00:00:00 2001 From: d3vyce Date: Sat, 7 Mar 2026 10:29:20 -0500 Subject: [PATCH] feat(security): add oauth helpers --- src/fastapi_toolsets/security/__init__.py | 15 +- src/fastapi_toolsets/security/oauth.py | 120 +++++++++++- tests/test_security.py | 218 +++++++++++++++++++++- 3 files changed, 344 insertions(+), 9 deletions(-) diff --git a/src/fastapi_toolsets/security/__init__.py b/src/fastapi_toolsets/security/__init__.py index 3d5505d..483b49b 100644 --- a/src/fastapi_toolsets/security/__init__.py +++ b/src/fastapi_toolsets/security/__init__.py @@ -1,7 +1,13 @@ """Authentication helpers for FastAPI using Security().""" from .abc import AuthSource -from .oauth import decode_oauth_state, encode_oauth_state +from .oauth import ( + oauth_build_authorization_redirect, + oauth_decode_state, + oauth_encode_state, + oauth_fetch_userinfo, + oauth_resolve_provider_urls, +) from .sources import APIKeyHeaderAuth, BearerTokenAuth, CookieAuth, MultiAuth __all__ = [ @@ -10,6 +16,9 @@ __all__ = [ "BearerTokenAuth", "CookieAuth", "MultiAuth", - "decode_oauth_state", - "encode_oauth_state", + "oauth_build_authorization_redirect", + "oauth_decode_state", + "oauth_encode_state", + "oauth_fetch_userinfo", + "oauth_resolve_provider_urls", ] diff --git a/src/fastapi_toolsets/security/oauth.py b/src/fastapi_toolsets/security/oauth.py index 4e95281..f06c467 100644 --- a/src/fastapi_toolsets/security/oauth.py +++ b/src/fastapi_toolsets/security/oauth.py @@ -1,14 +1,130 @@ """OAuth 2.0 / OIDC helper utilities.""" import base64 +from typing import Any +from urllib.parse import urlencode + +import httpx +from fastapi.responses import RedirectResponse + +_discovery_cache: dict[str, dict] = {} -def encode_oauth_state(url: str) -> str: +async def oauth_resolve_provider_urls( + discovery_url: str, +) -> tuple[str, str, str | None]: + """Fetch the OIDC discovery document and return endpoint URLs. + + Args: + discovery_url: URL of the provider's ``/.well-known/openid-configuration``. + + Returns: + A ``(authorization_url, token_url, userinfo_url)`` tuple. + *userinfo_url* is ``None`` when the provider does not advertise one. + """ + if discovery_url not in _discovery_cache: + async with httpx.AsyncClient() as client: + resp = await client.get(discovery_url) + resp.raise_for_status() + _discovery_cache[discovery_url] = resp.json() + cfg = _discovery_cache[discovery_url] + return ( + cfg["authorization_endpoint"], + cfg["token_endpoint"], + cfg.get("userinfo_endpoint"), + ) + + +async def oauth_fetch_userinfo( + *, + token_url: str, + userinfo_url: str, + code: str, + client_id: str, + client_secret: str, + redirect_uri: str, +) -> dict[str, Any]: + """Exchange an authorization code for tokens and return the userinfo payload. + + Performs the two-step OAuth 2.0 / OIDC token exchange: + + 1. POSTs the authorization *code* to *token_url* to obtain an access token. + 2. GETs *userinfo_url* using that access token as a Bearer credential. + + Args: + token_url: Provider's token endpoint. + userinfo_url: Provider's userinfo endpoint. + code: Authorization code received from the provider's callback. + client_id: OAuth application client ID. + client_secret: OAuth application client secret. + redirect_uri: Redirect URI that was used in the authorization request. + + Returns: + The JSON payload returned by the userinfo endpoint as a plain ``dict``. + """ + async with httpx.AsyncClient() as client: + token_resp = await client.post( + token_url, + data={ + "grant_type": "authorization_code", + "code": code, + "client_id": client_id, + "client_secret": client_secret, + "redirect_uri": redirect_uri, + }, + headers={"Accept": "application/json"}, + ) + token_resp.raise_for_status() + access_token = token_resp.json()["access_token"] + + userinfo_resp = await client.get( + userinfo_url, + headers={"Authorization": f"Bearer {access_token}"}, + ) + userinfo_resp.raise_for_status() + return userinfo_resp.json() + + +def oauth_build_authorization_redirect( + authorization_url: str, + *, + client_id: str, + scopes: str, + redirect_uri: str, + destination: str, +) -> RedirectResponse: + """Return an OAuth 2.0 authorization ``RedirectResponse``. + + Args: + authorization_url: Provider's authorization endpoint. + client_id: OAuth application client ID. + scopes: Space-separated list of requested scopes. + redirect_uri: URI the provider should redirect back to after authorization. + destination: URL the user should be sent to after the full OAuth flow + completes (encoded as ``state``). + + Returns: + A :class:`~fastapi.responses.RedirectResponse` to the provider's + authorization page. + """ + params = urlencode( + { + "client_id": client_id, + "response_type": "code", + "scope": scopes, + "redirect_uri": redirect_uri, + "state": oauth_encode_state(destination), + } + ) + return RedirectResponse(f"{authorization_url}?{params}") + + +def oauth_encode_state(url: str) -> str: """Base64url-encode a URL to embed as an OAuth ``state`` parameter.""" return base64.urlsafe_b64encode(url.encode()).decode() -def decode_oauth_state(state: str | None, *, fallback: str) -> str: +def oauth_decode_state(state: str | None, *, fallback: str) -> str: """Decode a base64url OAuth ``state`` parameter. Handles missing padding (some providers strip ``=``). diff --git a/tests/test_security.py b/tests/test_security.py index 2e5d600..f2a7043 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -1,5 +1,8 @@ """Tests for fastapi_toolsets.security.""" +from unittest.mock import AsyncMock, MagicMock, patch +from urllib.parse import parse_qs, urlparse + import pytest from fastapi import FastAPI, Security from fastapi.testclient import TestClient @@ -11,6 +14,11 @@ from fastapi_toolsets.security import ( BearerTokenAuth, CookieAuth, MultiAuth, + oauth_build_authorization_redirect, + oauth_decode_state, + oauth_encode_state, + oauth_fetch_userinfo, + oauth_resolve_provider_urls, ) @@ -129,8 +137,6 @@ class TestBearerTokenAuth: response = client.get("/me", headers={"Authorization": "Bearer org_abc123"}) assert response.status_code == 401 - # --- extract() --- - @pytest.mark.anyio async def test_extract_no_header(self): from starlette.requests import Request @@ -196,8 +202,6 @@ class TestBearerTokenAuth: request = Request(scope) assert await bearer.extract(request) is None - # --- generate_token() --- - def test_generate_token_no_prefix(self): bearer = BearerTokenAuth(simple_validator) token = bearer.generate_token() @@ -962,3 +966,209 @@ class TestAuthSource: response = client.get("/me", headers={"X-Token": "s3cr3t"}) assert response.status_code == 200 assert response.json() == {"token": "s3cr3t"} + + +def _make_async_client_mock(get_return=None, post_return=None): + """Return a patched httpx.AsyncClient context-manager mock.""" + mock_client = AsyncMock() + if get_return is not None: + mock_client.get.return_value = get_return + if post_return is not None: + mock_client.post.return_value = post_return + cm = MagicMock() + cm.__aenter__ = AsyncMock(return_value=mock_client) + cm.__aexit__ = AsyncMock(return_value=None) + return cm, mock_client + + +class TestEncodeDecodeOAuthState: + def test_encode_returns_base64url_string(self): + result = oauth_encode_state("https://example.com/dashboard") + assert isinstance(result, str) + assert "+" not in result + assert "/" not in result + + def test_round_trip(self): + url = "https://example.com/after-login?next=/home" + assert oauth_decode_state(oauth_encode_state(url), fallback="/") == url + + def test_decode_none_returns_fallback(self): + assert oauth_decode_state(None, fallback="/home") == "/home" + + def test_decode_null_string_returns_fallback(self): + assert oauth_decode_state("null", fallback="/home") == "/home" + + def test_decode_invalid_base64_returns_fallback(self): + assert oauth_decode_state("!!!notbase64!!!", fallback="/home") == "/home" + + def test_decode_handles_missing_padding(self): + url = "https://example.com/x" + encoded = oauth_encode_state(url).rstrip("=") + assert oauth_decode_state(encoded, fallback="/") == url + + +class TestBuildAuthorizationRedirect: + def test_returns_redirect_response(self): + from fastapi.responses import RedirectResponse + + response = oauth_build_authorization_redirect( + "https://auth.example.com/authorize", + client_id="my-client", + scopes="openid email", + redirect_uri="https://app.example.com/callback", + destination="https://app.example.com/dashboard", + ) + assert isinstance(response, RedirectResponse) + + def test_redirect_location_contains_all_params(self): + response = oauth_build_authorization_redirect( + "https://auth.example.com/authorize", + client_id="my-client", + scopes="openid email", + redirect_uri="https://app.example.com/callback", + destination="https://app.example.com/dashboard", + ) + location = response.headers["location"] + parsed = urlparse(location) + assert ( + parsed.scheme + "://" + parsed.netloc + parsed.path + == "https://auth.example.com/authorize" + ) + params = parse_qs(parsed.query) + assert params["client_id"] == ["my-client"] + assert params["response_type"] == ["code"] + assert params["scope"] == ["openid email"] + assert params["redirect_uri"] == ["https://app.example.com/callback"] + assert ( + oauth_decode_state(params["state"][0], fallback="") + == "https://app.example.com/dashboard" + ) + + +class TestResolveProviderUrls: + def _discovery(self, *, userinfo=True): + doc = { + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token", + } + if userinfo: + doc["userinfo_endpoint"] = "https://auth.example.com/userinfo" + return doc + + @pytest.mark.anyio + async def test_returns_all_endpoints(self): + mock_resp = MagicMock() + mock_resp.raise_for_status = MagicMock() + mock_resp.json.return_value = self._discovery() + cm, mock_client = _make_async_client_mock(get_return=mock_resp) + + with patch("fastapi_toolsets.security.oauth._discovery_cache", {}): + with patch("httpx.AsyncClient", return_value=cm): + auth_url, token_url, userinfo_url = await oauth_resolve_provider_urls( + "https://auth.example.com/.well-known/openid-configuration" + ) + + assert auth_url == "https://auth.example.com/authorize" + assert token_url == "https://auth.example.com/token" + assert userinfo_url == "https://auth.example.com/userinfo" + + @pytest.mark.anyio + async def test_userinfo_url_none_when_absent(self): + mock_resp = MagicMock() + mock_resp.raise_for_status = MagicMock() + mock_resp.json.return_value = self._discovery(userinfo=False) + cm, mock_client = _make_async_client_mock(get_return=mock_resp) + + with patch("fastapi_toolsets.security.oauth._discovery_cache", {}): + with patch("httpx.AsyncClient", return_value=cm): + _, _, userinfo_url = await oauth_resolve_provider_urls( + "https://auth.example.com/.well-known/openid-configuration" + ) + + assert userinfo_url is None + + @pytest.mark.anyio + async def test_caches_discovery_document(self): + mock_resp = MagicMock() + mock_resp.raise_for_status = MagicMock() + mock_resp.json.return_value = self._discovery() + cm, mock_client = _make_async_client_mock(get_return=mock_resp) + + url = "https://auth.example.com/.well-known/openid-configuration" + with patch("fastapi_toolsets.security.oauth._discovery_cache", {}): + with patch("httpx.AsyncClient", return_value=cm): + await oauth_resolve_provider_urls(url) + await oauth_resolve_provider_urls(url) + + assert mock_client.get.call_count == 1 + + +class TestFetchUserinfo: + @pytest.mark.anyio + async def test_returns_userinfo_payload(self): + token_resp = MagicMock() + token_resp.raise_for_status = MagicMock() + token_resp.json.return_value = {"access_token": "tok123"} + + userinfo_resp = MagicMock() + userinfo_resp.raise_for_status = MagicMock() + userinfo_resp.json.return_value = { + "sub": "user-1", + "email": "alice@example.com", + } + + cm, mock_client = _make_async_client_mock( + post_return=token_resp, get_return=userinfo_resp + ) + + with patch("httpx.AsyncClient", return_value=cm): + result = await oauth_fetch_userinfo( + token_url="https://auth.example.com/token", + userinfo_url="https://auth.example.com/userinfo", + code="authcode123", + client_id="client-id", + client_secret="client-secret", + redirect_uri="https://app.example.com/callback", + ) + + assert result == {"sub": "user-1", "email": "alice@example.com"} + + @pytest.mark.anyio + async def test_posts_correct_token_request_and_uses_bearer(self): + token_resp = MagicMock() + token_resp.raise_for_status = MagicMock() + token_resp.json.return_value = {"access_token": "tok123"} + + userinfo_resp = MagicMock() + userinfo_resp.raise_for_status = MagicMock() + userinfo_resp.json.return_value = {} + + cm, mock_client = _make_async_client_mock( + post_return=token_resp, get_return=userinfo_resp + ) + + with patch("httpx.AsyncClient", return_value=cm): + await oauth_fetch_userinfo( + token_url="https://auth.example.com/token", + userinfo_url="https://auth.example.com/userinfo", + code="authcode123", + client_id="my-client", + client_secret="my-secret", + redirect_uri="https://app.example.com/callback", + ) + + mock_client.post.assert_called_once_with( + "https://auth.example.com/token", + data={ + "grant_type": "authorization_code", + "code": "authcode123", + "client_id": "my-client", + "client_secret": "my-secret", + "redirect_uri": "https://app.example.com/callback", + }, + headers={"Accept": "application/json"}, + ) + mock_client.get.assert_called_once_with( + "https://auth.example.com/userinfo", + headers={"Authorization": "Bearer tok123"}, + )