feat(security): add oauth helpers

This commit is contained in:
2026-03-07 10:29:20 -05:00
parent f13e61c5a7
commit 233362dd35
3 changed files with 344 additions and 9 deletions

View File

@@ -1,5 +1,8 @@
"""Tests for fastapi_toolsets.security."""
from unittest.mock import AsyncMock, MagicMock, patch
from urllib.parse import parse_qs, urlparse
import pytest
from fastapi import FastAPI, Security
from fastapi.testclient import TestClient
@@ -11,6 +14,11 @@ from fastapi_toolsets.security import (
BearerTokenAuth,
CookieAuth,
MultiAuth,
oauth_build_authorization_redirect,
oauth_decode_state,
oauth_encode_state,
oauth_fetch_userinfo,
oauth_resolve_provider_urls,
)
@@ -129,8 +137,6 @@ class TestBearerTokenAuth:
response = client.get("/me", headers={"Authorization": "Bearer org_abc123"})
assert response.status_code == 401
# --- extract() ---
@pytest.mark.anyio
async def test_extract_no_header(self):
from starlette.requests import Request
@@ -196,8 +202,6 @@ class TestBearerTokenAuth:
request = Request(scope)
assert await bearer.extract(request) is None
# --- generate_token() ---
def test_generate_token_no_prefix(self):
bearer = BearerTokenAuth(simple_validator)
token = bearer.generate_token()
@@ -962,3 +966,209 @@ class TestAuthSource:
response = client.get("/me", headers={"X-Token": "s3cr3t"})
assert response.status_code == 200
assert response.json() == {"token": "s3cr3t"}
def _make_async_client_mock(get_return=None, post_return=None):
"""Return a patched httpx.AsyncClient context-manager mock."""
mock_client = AsyncMock()
if get_return is not None:
mock_client.get.return_value = get_return
if post_return is not None:
mock_client.post.return_value = post_return
cm = MagicMock()
cm.__aenter__ = AsyncMock(return_value=mock_client)
cm.__aexit__ = AsyncMock(return_value=None)
return cm, mock_client
class TestEncodeDecodeOAuthState:
def test_encode_returns_base64url_string(self):
result = oauth_encode_state("https://example.com/dashboard")
assert isinstance(result, str)
assert "+" not in result
assert "/" not in result
def test_round_trip(self):
url = "https://example.com/after-login?next=/home"
assert oauth_decode_state(oauth_encode_state(url), fallback="/") == url
def test_decode_none_returns_fallback(self):
assert oauth_decode_state(None, fallback="/home") == "/home"
def test_decode_null_string_returns_fallback(self):
assert oauth_decode_state("null", fallback="/home") == "/home"
def test_decode_invalid_base64_returns_fallback(self):
assert oauth_decode_state("!!!notbase64!!!", fallback="/home") == "/home"
def test_decode_handles_missing_padding(self):
url = "https://example.com/x"
encoded = oauth_encode_state(url).rstrip("=")
assert oauth_decode_state(encoded, fallback="/") == url
class TestBuildAuthorizationRedirect:
def test_returns_redirect_response(self):
from fastapi.responses import RedirectResponse
response = oauth_build_authorization_redirect(
"https://auth.example.com/authorize",
client_id="my-client",
scopes="openid email",
redirect_uri="https://app.example.com/callback",
destination="https://app.example.com/dashboard",
)
assert isinstance(response, RedirectResponse)
def test_redirect_location_contains_all_params(self):
response = oauth_build_authorization_redirect(
"https://auth.example.com/authorize",
client_id="my-client",
scopes="openid email",
redirect_uri="https://app.example.com/callback",
destination="https://app.example.com/dashboard",
)
location = response.headers["location"]
parsed = urlparse(location)
assert (
parsed.scheme + "://" + parsed.netloc + parsed.path
== "https://auth.example.com/authorize"
)
params = parse_qs(parsed.query)
assert params["client_id"] == ["my-client"]
assert params["response_type"] == ["code"]
assert params["scope"] == ["openid email"]
assert params["redirect_uri"] == ["https://app.example.com/callback"]
assert (
oauth_decode_state(params["state"][0], fallback="")
== "https://app.example.com/dashboard"
)
class TestResolveProviderUrls:
def _discovery(self, *, userinfo=True):
doc = {
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
}
if userinfo:
doc["userinfo_endpoint"] = "https://auth.example.com/userinfo"
return doc
@pytest.mark.anyio
async def test_returns_all_endpoints(self):
mock_resp = MagicMock()
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = self._discovery()
cm, mock_client = _make_async_client_mock(get_return=mock_resp)
with patch("fastapi_toolsets.security.oauth._discovery_cache", {}):
with patch("httpx.AsyncClient", return_value=cm):
auth_url, token_url, userinfo_url = await oauth_resolve_provider_urls(
"https://auth.example.com/.well-known/openid-configuration"
)
assert auth_url == "https://auth.example.com/authorize"
assert token_url == "https://auth.example.com/token"
assert userinfo_url == "https://auth.example.com/userinfo"
@pytest.mark.anyio
async def test_userinfo_url_none_when_absent(self):
mock_resp = MagicMock()
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = self._discovery(userinfo=False)
cm, mock_client = _make_async_client_mock(get_return=mock_resp)
with patch("fastapi_toolsets.security.oauth._discovery_cache", {}):
with patch("httpx.AsyncClient", return_value=cm):
_, _, userinfo_url = await oauth_resolve_provider_urls(
"https://auth.example.com/.well-known/openid-configuration"
)
assert userinfo_url is None
@pytest.mark.anyio
async def test_caches_discovery_document(self):
mock_resp = MagicMock()
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = self._discovery()
cm, mock_client = _make_async_client_mock(get_return=mock_resp)
url = "https://auth.example.com/.well-known/openid-configuration"
with patch("fastapi_toolsets.security.oauth._discovery_cache", {}):
with patch("httpx.AsyncClient", return_value=cm):
await oauth_resolve_provider_urls(url)
await oauth_resolve_provider_urls(url)
assert mock_client.get.call_count == 1
class TestFetchUserinfo:
@pytest.mark.anyio
async def test_returns_userinfo_payload(self):
token_resp = MagicMock()
token_resp.raise_for_status = MagicMock()
token_resp.json.return_value = {"access_token": "tok123"}
userinfo_resp = MagicMock()
userinfo_resp.raise_for_status = MagicMock()
userinfo_resp.json.return_value = {
"sub": "user-1",
"email": "alice@example.com",
}
cm, mock_client = _make_async_client_mock(
post_return=token_resp, get_return=userinfo_resp
)
with patch("httpx.AsyncClient", return_value=cm):
result = await oauth_fetch_userinfo(
token_url="https://auth.example.com/token",
userinfo_url="https://auth.example.com/userinfo",
code="authcode123",
client_id="client-id",
client_secret="client-secret",
redirect_uri="https://app.example.com/callback",
)
assert result == {"sub": "user-1", "email": "alice@example.com"}
@pytest.mark.anyio
async def test_posts_correct_token_request_and_uses_bearer(self):
token_resp = MagicMock()
token_resp.raise_for_status = MagicMock()
token_resp.json.return_value = {"access_token": "tok123"}
userinfo_resp = MagicMock()
userinfo_resp.raise_for_status = MagicMock()
userinfo_resp.json.return_value = {}
cm, mock_client = _make_async_client_mock(
post_return=token_resp, get_return=userinfo_resp
)
with patch("httpx.AsyncClient", return_value=cm):
await oauth_fetch_userinfo(
token_url="https://auth.example.com/token",
userinfo_url="https://auth.example.com/userinfo",
code="authcode123",
client_id="my-client",
client_secret="my-secret",
redirect_uri="https://app.example.com/callback",
)
mock_client.post.assert_called_once_with(
"https://auth.example.com/token",
data={
"grant_type": "authorization_code",
"code": "authcode123",
"client_id": "my-client",
"client_secret": "my-secret",
"redirect_uri": "https://app.example.com/callback",
},
headers={"Accept": "application/json"},
)
mock_client.get.assert_called_once_with(
"https://auth.example.com/userinfo",
headers={"Authorization": "Bearer tok123"},
)