mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 06:36:26 +02:00
feat(security): add oauth helpers
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
"""Tests for fastapi_toolsets.security."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Security
|
||||
from fastapi.testclient import TestClient
|
||||
@@ -11,6 +14,11 @@ from fastapi_toolsets.security import (
|
||||
BearerTokenAuth,
|
||||
CookieAuth,
|
||||
MultiAuth,
|
||||
oauth_build_authorization_redirect,
|
||||
oauth_decode_state,
|
||||
oauth_encode_state,
|
||||
oauth_fetch_userinfo,
|
||||
oauth_resolve_provider_urls,
|
||||
)
|
||||
|
||||
|
||||
@@ -129,8 +137,6 @@ class TestBearerTokenAuth:
|
||||
response = client.get("/me", headers={"Authorization": "Bearer org_abc123"})
|
||||
assert response.status_code == 401
|
||||
|
||||
# --- extract() ---
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_extract_no_header(self):
|
||||
from starlette.requests import Request
|
||||
@@ -196,8 +202,6 @@ class TestBearerTokenAuth:
|
||||
request = Request(scope)
|
||||
assert await bearer.extract(request) is None
|
||||
|
||||
# --- generate_token() ---
|
||||
|
||||
def test_generate_token_no_prefix(self):
|
||||
bearer = BearerTokenAuth(simple_validator)
|
||||
token = bearer.generate_token()
|
||||
@@ -962,3 +966,209 @@ class TestAuthSource:
|
||||
response = client.get("/me", headers={"X-Token": "s3cr3t"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"token": "s3cr3t"}
|
||||
|
||||
|
||||
def _make_async_client_mock(get_return=None, post_return=None):
|
||||
"""Return a patched httpx.AsyncClient context-manager mock."""
|
||||
mock_client = AsyncMock()
|
||||
if get_return is not None:
|
||||
mock_client.get.return_value = get_return
|
||||
if post_return is not None:
|
||||
mock_client.post.return_value = post_return
|
||||
cm = MagicMock()
|
||||
cm.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
cm.__aexit__ = AsyncMock(return_value=None)
|
||||
return cm, mock_client
|
||||
|
||||
|
||||
class TestEncodeDecodeOAuthState:
|
||||
def test_encode_returns_base64url_string(self):
|
||||
result = oauth_encode_state("https://example.com/dashboard")
|
||||
assert isinstance(result, str)
|
||||
assert "+" not in result
|
||||
assert "/" not in result
|
||||
|
||||
def test_round_trip(self):
|
||||
url = "https://example.com/after-login?next=/home"
|
||||
assert oauth_decode_state(oauth_encode_state(url), fallback="/") == url
|
||||
|
||||
def test_decode_none_returns_fallback(self):
|
||||
assert oauth_decode_state(None, fallback="/home") == "/home"
|
||||
|
||||
def test_decode_null_string_returns_fallback(self):
|
||||
assert oauth_decode_state("null", fallback="/home") == "/home"
|
||||
|
||||
def test_decode_invalid_base64_returns_fallback(self):
|
||||
assert oauth_decode_state("!!!notbase64!!!", fallback="/home") == "/home"
|
||||
|
||||
def test_decode_handles_missing_padding(self):
|
||||
url = "https://example.com/x"
|
||||
encoded = oauth_encode_state(url).rstrip("=")
|
||||
assert oauth_decode_state(encoded, fallback="/") == url
|
||||
|
||||
|
||||
class TestBuildAuthorizationRedirect:
|
||||
def test_returns_redirect_response(self):
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
response = oauth_build_authorization_redirect(
|
||||
"https://auth.example.com/authorize",
|
||||
client_id="my-client",
|
||||
scopes="openid email",
|
||||
redirect_uri="https://app.example.com/callback",
|
||||
destination="https://app.example.com/dashboard",
|
||||
)
|
||||
assert isinstance(response, RedirectResponse)
|
||||
|
||||
def test_redirect_location_contains_all_params(self):
|
||||
response = oauth_build_authorization_redirect(
|
||||
"https://auth.example.com/authorize",
|
||||
client_id="my-client",
|
||||
scopes="openid email",
|
||||
redirect_uri="https://app.example.com/callback",
|
||||
destination="https://app.example.com/dashboard",
|
||||
)
|
||||
location = response.headers["location"]
|
||||
parsed = urlparse(location)
|
||||
assert (
|
||||
parsed.scheme + "://" + parsed.netloc + parsed.path
|
||||
== "https://auth.example.com/authorize"
|
||||
)
|
||||
params = parse_qs(parsed.query)
|
||||
assert params["client_id"] == ["my-client"]
|
||||
assert params["response_type"] == ["code"]
|
||||
assert params["scope"] == ["openid email"]
|
||||
assert params["redirect_uri"] == ["https://app.example.com/callback"]
|
||||
assert (
|
||||
oauth_decode_state(params["state"][0], fallback="")
|
||||
== "https://app.example.com/dashboard"
|
||||
)
|
||||
|
||||
|
||||
class TestResolveProviderUrls:
|
||||
def _discovery(self, *, userinfo=True):
|
||||
doc = {
|
||||
"authorization_endpoint": "https://auth.example.com/authorize",
|
||||
"token_endpoint": "https://auth.example.com/token",
|
||||
}
|
||||
if userinfo:
|
||||
doc["userinfo_endpoint"] = "https://auth.example.com/userinfo"
|
||||
return doc
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_returns_all_endpoints(self):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
mock_resp.json.return_value = self._discovery()
|
||||
cm, mock_client = _make_async_client_mock(get_return=mock_resp)
|
||||
|
||||
with patch("fastapi_toolsets.security.oauth._discovery_cache", {}):
|
||||
with patch("httpx.AsyncClient", return_value=cm):
|
||||
auth_url, token_url, userinfo_url = await oauth_resolve_provider_urls(
|
||||
"https://auth.example.com/.well-known/openid-configuration"
|
||||
)
|
||||
|
||||
assert auth_url == "https://auth.example.com/authorize"
|
||||
assert token_url == "https://auth.example.com/token"
|
||||
assert userinfo_url == "https://auth.example.com/userinfo"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_userinfo_url_none_when_absent(self):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
mock_resp.json.return_value = self._discovery(userinfo=False)
|
||||
cm, mock_client = _make_async_client_mock(get_return=mock_resp)
|
||||
|
||||
with patch("fastapi_toolsets.security.oauth._discovery_cache", {}):
|
||||
with patch("httpx.AsyncClient", return_value=cm):
|
||||
_, _, userinfo_url = await oauth_resolve_provider_urls(
|
||||
"https://auth.example.com/.well-known/openid-configuration"
|
||||
)
|
||||
|
||||
assert userinfo_url is None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_caches_discovery_document(self):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
mock_resp.json.return_value = self._discovery()
|
||||
cm, mock_client = _make_async_client_mock(get_return=mock_resp)
|
||||
|
||||
url = "https://auth.example.com/.well-known/openid-configuration"
|
||||
with patch("fastapi_toolsets.security.oauth._discovery_cache", {}):
|
||||
with patch("httpx.AsyncClient", return_value=cm):
|
||||
await oauth_resolve_provider_urls(url)
|
||||
await oauth_resolve_provider_urls(url)
|
||||
|
||||
assert mock_client.get.call_count == 1
|
||||
|
||||
|
||||
class TestFetchUserinfo:
|
||||
@pytest.mark.anyio
|
||||
async def test_returns_userinfo_payload(self):
|
||||
token_resp = MagicMock()
|
||||
token_resp.raise_for_status = MagicMock()
|
||||
token_resp.json.return_value = {"access_token": "tok123"}
|
||||
|
||||
userinfo_resp = MagicMock()
|
||||
userinfo_resp.raise_for_status = MagicMock()
|
||||
userinfo_resp.json.return_value = {
|
||||
"sub": "user-1",
|
||||
"email": "alice@example.com",
|
||||
}
|
||||
|
||||
cm, mock_client = _make_async_client_mock(
|
||||
post_return=token_resp, get_return=userinfo_resp
|
||||
)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=cm):
|
||||
result = await oauth_fetch_userinfo(
|
||||
token_url="https://auth.example.com/token",
|
||||
userinfo_url="https://auth.example.com/userinfo",
|
||||
code="authcode123",
|
||||
client_id="client-id",
|
||||
client_secret="client-secret",
|
||||
redirect_uri="https://app.example.com/callback",
|
||||
)
|
||||
|
||||
assert result == {"sub": "user-1", "email": "alice@example.com"}
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_posts_correct_token_request_and_uses_bearer(self):
|
||||
token_resp = MagicMock()
|
||||
token_resp.raise_for_status = MagicMock()
|
||||
token_resp.json.return_value = {"access_token": "tok123"}
|
||||
|
||||
userinfo_resp = MagicMock()
|
||||
userinfo_resp.raise_for_status = MagicMock()
|
||||
userinfo_resp.json.return_value = {}
|
||||
|
||||
cm, mock_client = _make_async_client_mock(
|
||||
post_return=token_resp, get_return=userinfo_resp
|
||||
)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=cm):
|
||||
await oauth_fetch_userinfo(
|
||||
token_url="https://auth.example.com/token",
|
||||
userinfo_url="https://auth.example.com/userinfo",
|
||||
code="authcode123",
|
||||
client_id="my-client",
|
||||
client_secret="my-secret",
|
||||
redirect_uri="https://app.example.com/callback",
|
||||
)
|
||||
|
||||
mock_client.post.assert_called_once_with(
|
||||
"https://auth.example.com/token",
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"code": "authcode123",
|
||||
"client_id": "my-client",
|
||||
"client_secret": "my-secret",
|
||||
"redirect_uri": "https://app.example.com/callback",
|
||||
},
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
mock_client.get.assert_called_once_with(
|
||||
"https://auth.example.com/userinfo",
|
||||
headers={"Authorization": "Bearer tok123"},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user