feat: add security module

This commit is contained in:
2026-03-04 11:02:54 -05:00
parent 2c494fcd17
commit a17ea9b820
11 changed files with 1813 additions and 0 deletions

267
docs/module/security.md Normal file
View File

@@ -0,0 +1,267 @@
# Security
Composable authentication helpers for FastAPI that use `Security()` for OpenAPI documentation and accept user-provided validator functions with full type flexibility.
## Overview
The `security` module provides four auth source classes and a `MultiAuth` factory. Each class wraps a FastAPI security scheme for OpenAPI and accepts a validator function called as:
```python
await validator(credential, **kwargs)
```
where `kwargs` are the extra keyword arguments provided at instantiation (roles, permissions, enums, etc.). The validator returns the authenticated identity (e.g. a `User` model) which becomes the route dependency value.
```python
from fastapi import Security
from fastapi_toolsets.security import BearerTokenAuth
async def verify_token(token: str, *, role: str) -> User:
user = await db.get_by_token(token)
if not user or user.role != role:
raise UnauthorizedError()
return user
bearer_admin = BearerTokenAuth(verify_token, role="admin")
@app.get("/admin")
async def admin_route(user: User = Security(bearer_admin)):
return user
```
## Auth sources
### [`BearerTokenAuth`](../reference/security.md#fastapi_toolsets.security.BearerTokenAuth)
Reads the `Authorization: Bearer <token>` header. Wraps `HTTPBearer` for OpenAPI.
```python
from fastapi_toolsets.security import BearerTokenAuth
bearer = BearerTokenAuth(validator=verify_token)
@app.get("/me")
async def me(user: User = Security(bearer)):
return user
```
#### Token prefix
The optional `prefix` parameter restricts a `BearerTokenAuth` instance to tokens
that start with a given string. The prefix is **kept** in the value passed to the
validator — store and compare tokens with their prefix included.
This lets you deploy multiple `BearerTokenAuth` instances in the same application
and disambiguate them efficiently in `MultiAuth`:
```python
user_bearer = BearerTokenAuth(verify_user, prefix="user_") # matches "Bearer user_..."
org_bearer = BearerTokenAuth(verify_org, prefix="org_") # matches "Bearer org_..."
```
Use [`generate_token()`](#token-generation) to create correctly-prefixed tokens.
#### Token generation
`BearerTokenAuth.generate_token()` produces a secure random token ready to store
in your database and return to the client. If a prefix is configured it is
prepended automatically:
```python
bearer = BearerTokenAuth(verify_token, prefix="user_")
token = bearer.generate_token() # e.g. "user_Xk3mN..."
await db.store_token(user_id, token)
return {"access_token": token, "token_type": "bearer"}
```
The client sends `Authorization: Bearer user_Xk3mN...` and the validator receives
the full token (prefix included) to compare against the stored value.
### [`CookieAuth`](../reference/security.md#fastapi_toolsets.security.CookieAuth)
Reads a named cookie. Wraps `APIKeyCookie` for OpenAPI.
```python
from fastapi_toolsets.security import CookieAuth
cookie_auth = CookieAuth("session", validator=verify_session)
@app.get("/me")
async def me(user: User = Security(cookie_auth)):
return user
```
### [`OAuth2Auth`](../reference/security.md#fastapi_toolsets.security.OAuth2Auth)
Reads the `Authorization: Bearer <token>` header and registers the token endpoint
in OpenAPI via `OAuth2PasswordBearer`.
```python
from fastapi_toolsets.security import OAuth2Auth
oauth2_auth = OAuth2Auth(token_url="/token", validator=verify_token)
@app.get("/me")
async def me(user: User = Security(oauth2_auth)):
return user
```
### [`OpenIDAuth`](../reference/security.md#fastapi_toolsets.security.OpenIDAuth)
Reads the `Authorization: Bearer <token>` header and registers the OpenID Connect
discovery URL in OpenAPI via `OpenIdConnect`. Token validation is fully delegated
to your validator — use any OIDC / JWT library (`authlib`, `python-jose`, `PyJWT`).
```python
from fastapi_toolsets.security import OpenIDAuth
async def verify_google_token(token: str, *, audience: str) -> User:
payload = jwt.decode(token, google_public_keys, algorithms=["RS256"],
audience=audience)
return User(email=payload["email"], name=payload["name"])
google_auth = OpenIDAuth(
"https://accounts.google.com/.well-known/openid-configuration",
verify_google_token,
audience="my-client-id",
)
@app.get("/me")
async def me(user: User = Security(google_auth)):
return user
```
The discovery URL is used **only for OpenAPI documentation** — no requests are made
to it by this class. You are responsible for fetching and caching the provider's
public keys in your validator.
Multiple providers work naturally with `MultiAuth`:
```python
multi = MultiAuth(google_auth, github_auth)
@app.get("/data")
async def data(user: User = Security(multi)):
return user
```
## Typed validator kwargs
All auth classes forward extra instantiation keyword arguments to the validator.
Arguments can be any type — enums, strings, integers, etc. The validator returns
the authenticated identity, which FastAPI injects directly into the route handler.
```python
async def verify_token(token: str, *, role: Role, permission: str) -> User:
user = await decode_token(token)
if user.role != role or permission not in user.permissions:
raise UnauthorizedError()
return user
bearer = BearerTokenAuth(verify_token, role=Role.ADMIN, permission="billing:read")
```
Each auth instance is self-contained — create a separate instance per distinct
requirement instead of passing requirements through `Security(scopes=[...])`.
### Using `.require()` inline
If declaring a new top-level variable per role feels verbose, use `.require()` to
create a configured clone directly in the route decorator. The original instance
is not mutated:
```python
bearer = BearerTokenAuth(verify_token)
@app.get("/admin/stats")
async def admin_stats(user: User = Security(bearer.require(role=Role.ADMIN))):
return {"message": f"Hello admin {user.name}"}
@app.get("/profile")
async def profile(user: User = Security(bearer.require(role=Role.USER))):
return {"id": user.id, "name": user.name}
```
`.require()` kwargs are merged over existing ones — new values win on conflict.
The `prefix` (for `BearerTokenAuth`) and cookie name (for `CookieAuth`) are
always preserved.
`.require()` instances work transparently inside `MultiAuth`:
```python
multi = MultiAuth(
user_bearer.require(role=Role.USER),
org_bearer.require(role=Role.ADMIN),
)
```
## MultiAuth
[`MultiAuth`](../reference/security.md#fastapi_toolsets.security.MultiAuth) combines
multiple auth sources into a single callable. Sources are tried in order; the
first one that finds a credential wins.
```python
from fastapi_toolsets.security import MultiAuth
multi = MultiAuth(user_bearer, org_bearer, cookie_auth)
@app.get("/data")
async def data_route(user = Security(multi)):
return user
```
### Using `.require()` on MultiAuth
`MultiAuth` also supports `.require()`, which propagates the kwargs to every
source that implements it. Sources that do not (e.g. custom `AuthSource`
subclasses) are passed through unchanged:
```python
multi = MultiAuth(bearer, cookie)
@app.get("/admin")
async def admin(user: User = Security(multi.require(role=Role.ADMIN))):
return user
```
This is equivalent to calling `.require()` on each source individually:
```python
# These two are identical
multi.require(role=Role.ADMIN)
MultiAuth(
bearer.require(role=Role.ADMIN),
cookie.require(role=Role.ADMIN),
)
```
### Prefix-based dispatch
Because `extract()` is pure string matching (no I/O), prefix-based source
selection is essentially free. Only the matching source's validator (which may
involve DB or network I/O) is ever called:
```python
user_bearer = BearerTokenAuth(verify_user, prefix="user_")
org_bearer = BearerTokenAuth(verify_org, prefix="org_")
multi = MultiAuth(user_bearer, org_bearer)
# "Bearer user_alice" → only verify_user runs, receives "user_alice"
# "Bearer org_acme" → only verify_org runs, receives "org_acme"
```
Tokens are stored and compared **with their prefix** — use `generate_token()` on
each source to issue correctly-prefixed tokens:
```python
user_token = user_bearer.generate_token() # "user_..."
org_token = org_bearer.generate_token() # "org_..."
```
---
[:material-api: API Reference](../reference/security.md)

View File

@@ -0,0 +1,28 @@
# `security`
Here's the reference for the authentication helpers provided by the `security` module.
You can import them directly from `fastapi_toolsets.security`:
```python
from fastapi_toolsets.security import (
AuthSource,
BearerTokenAuth,
CookieAuth,
OAuth2Auth,
OpenIDAuth,
MultiAuth,
)
```
## ::: fastapi_toolsets.security.AuthSource
## ::: fastapi_toolsets.security.BearerTokenAuth
## ::: fastapi_toolsets.security.CookieAuth
## ::: fastapi_toolsets.security.OAuth2Auth
## ::: fastapi_toolsets.security.OpenIDAuth
## ::: fastapi_toolsets.security.MultiAuth

View File

@@ -0,0 +1,15 @@
"""Authentication helpers for FastAPI using Security()."""
from .abc import AuthSource
from .oauth import decode_oauth_state, encode_oauth_state
from .sources import APIKeyHeaderAuth, BearerTokenAuth, CookieAuth, MultiAuth
__all__ = [
"APIKeyHeaderAuth",
"AuthSource",
"BearerTokenAuth",
"CookieAuth",
"MultiAuth",
"decode_oauth_state",
"encode_oauth_state",
]

View File

@@ -0,0 +1,51 @@
"""Abstract base class for authentication sources."""
import inspect
from abc import ABC, abstractmethod
from typing import Any, Callable
from fastapi import Request
from fastapi.security import SecurityScopes
from fastapi_toolsets.exceptions import UnauthorizedError
async def _call_validator(
validator: Callable[..., Any], *args: Any, **kwargs: Any
) -> Any:
"""Call *validator* with *args* and *kwargs*, awaiting it if it is a coroutine function."""
if inspect.iscoroutinefunction(validator):
return await validator(*args, **kwargs)
return validator(*args, **kwargs)
class AuthSource(ABC):
"""Abstract base class for authentication sources."""
def __init__(self) -> None:
"""Set up the default FastAPI dependency signature."""
source = self
async def _call(
request: Request,
security_scopes: SecurityScopes, # noqa: ARG001
) -> Any:
credential = await source.extract(request)
if credential is None:
raise UnauthorizedError()
return await source.authenticate(credential)
self._call_fn: Callable[..., Any] = _call
self.__signature__ = inspect.signature(_call)
@abstractmethod
async def extract(self, request: Request) -> str | None:
"""Extract the raw credential from the request without validating."""
@abstractmethod
async def authenticate(self, credential: str) -> Any:
"""Validate a credential and return the authenticated identity."""
async def __call__(self, **kwargs: Any) -> Any:
"""FastAPI dependency dispatch."""
return await self._call_fn(**kwargs)

View File

@@ -0,0 +1,24 @@
"""OAuth 2.0 / OIDC helper utilities."""
import base64
def encode_oauth_state(url: str) -> str:
"""Base64url-encode a URL to embed as an OAuth ``state`` parameter."""
return base64.urlsafe_b64encode(url.encode()).decode()
def decode_oauth_state(state: str | None, *, fallback: str) -> str:
"""Decode a base64url OAuth ``state`` parameter.
Handles missing padding (some providers strip ``=``).
Returns *fallback* if *state* is absent, the literal string ``"null"``,
or cannot be decoded.
"""
if not state or state == "null":
return fallback
try:
padded = state + "=" * (4 - len(state) % 4)
return base64.urlsafe_b64decode(padded).decode()
except Exception:
return fallback

View File

@@ -0,0 +1,8 @@
"""Built-in authentication source implementations."""
from .header import APIKeyHeaderAuth
from .bearer import BearerTokenAuth
from .cookie import CookieAuth
from .multi import MultiAuth
__all__ = ["APIKeyHeaderAuth", "BearerTokenAuth", "CookieAuth", "MultiAuth"]

View File

@@ -0,0 +1,122 @@
"""Bearer token authentication source."""
import inspect
import secrets
from typing import Annotated, Any, Callable
from fastapi import Depends
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer, SecurityScopes
from fastapi_toolsets.exceptions import UnauthorizedError
from ..abc import AuthSource, _call_validator
class BearerTokenAuth(AuthSource):
"""Bearer token authentication source.
Wraps :class:`fastapi.security.HTTPBearer` for OpenAPI documentation.
The validator is called as ``await validator(credential, **kwargs)``
where ``kwargs`` are the extra keyword arguments provided at instantiation.
Args:
validator: Sync or async callable that receives the credential and any
extra keyword arguments, and returns the authenticated identity
(e.g. a ``User`` model). Should raise
:class:`~fastapi_toolsets.exceptions.UnauthorizedError` on failure.
prefix: Optional token prefix (e.g. ``"user_"``). If set, only tokens
whose value starts with this prefix are matched. The prefix is
**kept** in the value passed to the validator — store and compare
tokens with their prefix included. Use :meth:`generate_token` to
create correctly-prefixed tokens. This enables multiple
``BearerTokenAuth`` instances in the same app (e.g. ``"user_"``
for user tokens, ``"org_"`` for org tokens).
**kwargs: Extra keyword arguments forwarded to the validator on every
call (e.g. ``role=Role.ADMIN``).
"""
def __init__(
self,
validator: Callable[..., Any],
*,
prefix: str | None = None,
**kwargs: Any,
) -> None:
self._validator = validator
self._prefix = prefix
self._kwargs = kwargs
self._scheme = HTTPBearer(auto_error=False)
_scheme = self._scheme
_validator = validator
_kwargs = kwargs
_prefix = prefix
async def _call(
security_scopes: SecurityScopes, # noqa: ARG001
credentials: Annotated[
HTTPAuthorizationCredentials | None, Depends(_scheme)
] = None,
) -> Any:
if credentials is None:
raise UnauthorizedError()
token = credentials.credentials
if _prefix is not None and not token.startswith(_prefix):
raise UnauthorizedError()
return await _call_validator(_validator, token, **_kwargs)
self._call_fn = _call
self.__signature__ = inspect.signature(_call)
async def extract(self, request: Any) -> str | None:
"""Extract the raw credential from the request without validating.
Returns ``None`` if no ``Authorization: Bearer`` header is present,
the token is empty, or the token does not match the configured prefix.
The prefix is included in the returned value.
"""
auth = request.headers.get("Authorization", "")
if not auth.startswith("Bearer "):
return None
token = auth[7:]
if not token:
return None
if self._prefix is not None and not token.startswith(self._prefix):
return None
return token
async def authenticate(self, credential: str) -> Any:
"""Validate a credential and return the identity.
Calls ``await validator(credential, **kwargs)`` where ``kwargs`` are
the extra keyword arguments provided at instantiation.
"""
return await _call_validator(self._validator, credential, **self._kwargs)
def require(self, **kwargs: Any) -> "BearerTokenAuth":
"""Return a new instance with additional (or overriding) validator kwargs."""
return BearerTokenAuth(
self._validator,
prefix=self._prefix,
**{**self._kwargs, **kwargs},
)
def generate_token(self, nbytes: int = 32) -> str:
"""Generate a secure random token for this auth source.
Returns a URL-safe random token. If a prefix is configured it is
prepended — the returned value is what you store in your database
and return to the client as-is.
Args:
nbytes: Number of random bytes before base64 encoding. The
resulting string is ``ceil(nbytes * 4 / 3)`` characters
(43 chars for the default 32 bytes). Defaults to 32.
Returns:
A ready-to-use token string (e.g. ``"user_Xk3..."``).
"""
token = secrets.token_urlsafe(nbytes)
if self._prefix is not None:
return f"{self._prefix}{token}"
return token

View File

@@ -0,0 +1,142 @@
"""Cookie-based authentication source."""
import base64
import hashlib
import hmac
import inspect
import json
import time
from typing import Annotated, Any, Callable
from fastapi import Depends, Request, Response
from fastapi.security import APIKeyCookie, SecurityScopes
from fastapi_toolsets.exceptions import UnauthorizedError
from ..abc import AuthSource, _call_validator
class CookieAuth(AuthSource):
"""Cookie-based authentication source.
Wraps :class:`fastapi.security.APIKeyCookie` for OpenAPI documentation.
Optionally signs the cookie with HMAC-SHA256 to provide stateless, tamper-
proof sessions without any database entry.
Args:
name: Cookie name.
validator: Sync or async callable that receives the cookie value
(plain, after signature verification when ``secret_key`` is set)
and any extra keyword arguments, and returns the authenticated
identity.
secret_key: When provided, the cookie is HMAC-SHA256 signed.
:meth:`set_cookie` embeds an expiry and signs the payload;
:meth:`extract` verifies the signature and expiry before handing
the plain value to the validator. When ``None`` (default), the raw
cookie value is passed to the validator as-is.
ttl: Cookie lifetime in seconds (default 24 h). Only used when
``secret_key`` is set.
**kwargs: Extra keyword arguments forwarded to the validator on every
call (e.g. ``role=Role.ADMIN``).
"""
def __init__(
self,
name: str,
validator: Callable[..., Any],
*,
secret_key: str | None = None,
ttl: int = 86400,
**kwargs: Any,
) -> None:
self._name = name
self._validator = validator
self._secret_key = secret_key
self._ttl = ttl
self._kwargs = kwargs
self._scheme = APIKeyCookie(name=name, auto_error=False)
_scheme = self._scheme
_self = self
_kwargs = kwargs
async def _call(
security_scopes: SecurityScopes, # noqa: ARG001
value: Annotated[str | None, Depends(_scheme)] = None,
) -> Any:
if value is None:
raise UnauthorizedError()
plain = _self._verify(value)
return await _call_validator(_self._validator, plain, **_kwargs)
self._call_fn = _call
self.__signature__ = inspect.signature(_call)
def _hmac(self, data: str) -> str:
assert self._secret_key is not None
return hmac.new(
self._secret_key.encode(), data.encode(), hashlib.sha256
).hexdigest()
def _sign(self, value: str) -> str:
data = base64.urlsafe_b64encode(
json.dumps({"v": value, "exp": int(time.time()) + self._ttl}).encode()
).decode()
return f"{data}.{self._hmac(data)}"
def _verify(self, cookie_value: str) -> str:
"""Return the plain value, verifying HMAC + expiry when signed."""
if not self._secret_key:
return cookie_value
try:
data, sig = cookie_value.rsplit(".", 1)
except ValueError:
raise UnauthorizedError()
if not hmac.compare_digest(self._hmac(data), sig):
raise UnauthorizedError()
try:
payload = json.loads(base64.urlsafe_b64decode(data))
value: str = payload["v"]
exp: int = payload["exp"]
except Exception:
raise UnauthorizedError()
if exp < int(time.time()):
raise UnauthorizedError()
return value
async def extract(self, request: Request) -> str | None:
return request.cookies.get(self._name)
async def authenticate(self, credential: str) -> Any:
plain = self._verify(credential)
return await _call_validator(self._validator, plain, **self._kwargs)
def require(self, **kwargs: Any) -> "CookieAuth":
"""Return a new instance with additional (or overriding) validator kwargs."""
return CookieAuth(
self._name,
self._validator,
secret_key=self._secret_key,
ttl=self._ttl,
**{**self._kwargs, **kwargs},
)
def set_cookie(self, response: Response, value: str) -> None:
"""Attach the cookie to *response*, signing it when ``secret_key`` is set."""
cookie_value = self._sign(value) if self._secret_key else value
response.set_cookie(
self._name,
cookie_value,
httponly=True,
samesite="lax",
max_age=self._ttl,
)
def delete_cookie(self, response: Response) -> None:
"""Clear the session cookie (logout)."""
response.delete_cookie(self._name, httponly=True, samesite="lax")

View File

@@ -0,0 +1,71 @@
"""API key header authentication source."""
import inspect
from typing import Annotated, Any, Callable
from fastapi import Depends, Request
from fastapi.security import APIKeyHeader, SecurityScopes
from fastapi_toolsets.exceptions import UnauthorizedError
from ..abc import AuthSource, _call_validator
class APIKeyHeaderAuth(AuthSource):
"""API key header authentication source.
Wraps :class:`fastapi.security.APIKeyHeader` for OpenAPI documentation.
The validator is called as ``await validator(api_key, **kwargs)``
where ``kwargs`` are the extra keyword arguments provided at instantiation.
Args:
name: HTTP header name that carries the API key (e.g. ``"X-API-Key"``).
validator: Sync or async callable that receives the API key and any
extra keyword arguments, and returns the authenticated identity.
Should raise :class:`~fastapi_toolsets.exceptions.UnauthorizedError`
on failure.
**kwargs: Extra keyword arguments forwarded to the validator on every
call (e.g. ``role=Role.ADMIN``).
"""
def __init__(
self,
name: str,
validator: Callable[..., Any],
**kwargs: Any,
) -> None:
self._name = name
self._validator = validator
self._kwargs = kwargs
self._scheme = APIKeyHeader(name=name, auto_error=False)
_scheme = self._scheme
_validator = validator
_kwargs = kwargs
async def _call(
security_scopes: SecurityScopes, # noqa: ARG001
api_key: Annotated[str | None, Depends(_scheme)] = None,
) -> Any:
if api_key is None:
raise UnauthorizedError()
return await _call_validator(_validator, api_key, **_kwargs)
self._call_fn = _call
self.__signature__ = inspect.signature(_call)
async def extract(self, request: Request) -> str | None:
"""Extract the API key from the configured header."""
return request.headers.get(self._name) or None
async def authenticate(self, credential: str) -> Any:
"""Validate a credential and return the identity."""
return await _call_validator(self._validator, credential, **self._kwargs)
def require(self, **kwargs: Any) -> "APIKeyHeaderAuth":
"""Return a new instance with additional (or overriding) validator kwargs."""
return APIKeyHeaderAuth(
self._name,
self._validator,
**{**self._kwargs, **kwargs},
)

View File

@@ -0,0 +1,121 @@
"""MultiAuth: combine multiple authentication sources into a single callable."""
import inspect
from typing import Any, cast
from fastapi import Request
from fastapi.security import SecurityScopes
from fastapi_toolsets.exceptions import UnauthorizedError
from ..abc import AuthSource
class MultiAuth:
"""Combine multiple authentication sources into a single callable.
Sources are tried in order; the first one whose
:meth:`~AuthSource.extract` returns a non-``None`` credential wins.
Its :meth:`~AuthSource.authenticate` is called and the result returned.
If a credential is found but the validator raises, the exception propagates
immediately — the remaining sources are **not** tried. This prevents
silent fallthrough on invalid credentials.
If no source provides a credential,
:class:`~fastapi_toolsets.exceptions.UnauthorizedError` is raised.
The :meth:`~AuthSource.extract` method of each source performs only
string matching (no I/O), so prefix-based dispatch is essentially free.
Any :class:`~AuthSource` subclass — including user-defined ones — can be
passed as a source.
Args:
*sources: Auth source instances to try in order.
Example::
user_bearer = BearerTokenAuth(verify_user, prefix="user_")
org_bearer = BearerTokenAuth(verify_org, prefix="org_")
cookie = CookieAuth("session", verify_session)
multi = MultiAuth(user_bearer, org_bearer, cookie)
@app.get("/data")
async def data_route(user = Security(multi)):
return user
# Apply a shared requirement to all sources at once
@app.get("/admin")
async def admin_route(user = Security(multi.require(role=Role.ADMIN))):
return user
"""
def __init__(self, *sources: AuthSource) -> None:
self._sources = sources
_sources = sources
async def _call(
request: Request,
security_scopes: SecurityScopes, # noqa: ARG001
**kwargs: Any, # noqa: ARG001 — absorbs scheme values injected by FastAPI
) -> Any:
for source in _sources:
credential = await source.extract(request)
if credential is not None:
return await source.authenticate(credential)
raise UnauthorizedError()
self._call_fn = _call
# Build a merged signature that includes the security-scheme Depends()
# parameters from every source so FastAPI registers them in OpenAPI docs.
seen: set[str] = {"request", "security_scopes"}
merged: list[inspect.Parameter] = [
inspect.Parameter(
"request",
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=Request,
),
inspect.Parameter(
"security_scopes",
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=SecurityScopes,
),
]
for i, source in enumerate(sources):
for name, param in inspect.signature(source).parameters.items():
if name in seen:
continue
merged.append(param.replace(name=f"_s{i}_{name}"))
seen.add(name)
self.__signature__ = inspect.Signature(merged, return_annotation=Any)
async def __call__(self, **kwargs: Any) -> Any:
return await self._call_fn(**kwargs)
def require(self, **kwargs: Any) -> "MultiAuth":
"""Return a new :class:`MultiAuth` with kwargs forwarded to each source.
Calls ``.require(**kwargs)`` on every source that supports it. Sources
that do not implement ``.require()`` (e.g. custom :class:`~AuthSource`
subclasses) are passed through unchanged.
New kwargs are merged over each source's existing kwargs — new values
win on conflict::
multi = MultiAuth(bearer, cookie)
@app.get("/admin")
async def admin(user = Security(multi.require(role=Role.ADMIN))):
return user
"""
new_sources = tuple(
cast(Any, source).require(**kwargs)
if hasattr(source, "require")
else source
for source in self._sources
)
return MultiAuth(*new_sources)

964
tests/test_security.py Normal file
View File

@@ -0,0 +1,964 @@
"""Tests for fastapi_toolsets.security."""
import pytest
from fastapi import FastAPI, Security
from fastapi.testclient import TestClient
from fastapi_toolsets.exceptions import UnauthorizedError, init_exceptions_handlers
from fastapi_toolsets.security import (
APIKeyHeaderAuth,
AuthSource,
BearerTokenAuth,
CookieAuth,
MultiAuth,
)
def _app(*routes_setup_fns):
"""Build a minimal FastAPI test app with exception handlers."""
app = FastAPI()
init_exceptions_handlers(app)
for fn in routes_setup_fns:
fn(app)
return app
VALID_TOKEN = "secret"
VALID_COOKIE = "session123"
async def simple_validator(credential: str) -> dict:
if credential != VALID_TOKEN:
raise UnauthorizedError()
return {"user": "alice"}
async def role_validator(credential: str, *, role: str) -> dict:
if credential != VALID_TOKEN:
raise UnauthorizedError()
return {"user": "alice", "role": role}
async def cookie_validator(value: str) -> dict:
if value != VALID_COOKIE:
raise UnauthorizedError()
return {"session": value}
class TestBearerTokenAuth:
def test_valid_token_returns_identity(self):
bearer = BearerTokenAuth(simple_validator)
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(bearer)):
return user
client = TestClient(_app(setup))
response = client.get("/me", headers={"Authorization": f"Bearer {VALID_TOKEN}"})
assert response.status_code == 200
assert response.json() == {"user": "alice"}
def test_missing_header_returns_401(self):
bearer = BearerTokenAuth(simple_validator)
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(bearer)):
return user
client = TestClient(_app(setup))
response = client.get("/me")
assert response.status_code == 401
def test_invalid_token_returns_401(self):
bearer = BearerTokenAuth(simple_validator)
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(bearer)):
return user
client = TestClient(_app(setup))
response = client.get("/me", headers={"Authorization": "Bearer wrong"})
assert response.status_code == 401
def test_kwargs_forwarded_to_validator(self):
bearer = BearerTokenAuth(role_validator, role="admin")
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(bearer)):
return user
client = TestClient(_app(setup))
response = client.get("/me", headers={"Authorization": f"Bearer {VALID_TOKEN}"})
assert response.status_code == 200
assert response.json() == {"user": "alice", "role": "admin"}
def test_prefix_matching_passes_full_token(self):
"""Token with matching prefix: full token (with prefix) is passed to validator."""
received: list[str] = []
async def capturing_validator(credential: str) -> dict:
received.append(credential)
return {"user": "alice"}
bearer = BearerTokenAuth(capturing_validator, prefix="user_")
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(bearer)):
return user
client = TestClient(_app(setup))
response = client.get("/me", headers={"Authorization": "Bearer user_abc123"})
assert response.status_code == 200
# Prefix is kept — validator receives the full token as stored in DB
assert received == ["user_abc123"]
def test_prefix_mismatch_returns_401(self):
bearer = BearerTokenAuth(simple_validator, prefix="user_")
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(bearer)):
return user
client = TestClient(_app(setup))
response = client.get("/me", headers={"Authorization": "Bearer org_abc123"})
assert response.status_code == 401
# --- extract() ---
@pytest.mark.anyio
async def test_extract_no_header(self):
from starlette.requests import Request
bearer = BearerTokenAuth(simple_validator)
scope = {"type": "http", "method": "GET", "path": "/", "headers": []}
request = Request(scope)
assert await bearer.extract(request) is None
@pytest.mark.anyio
async def test_extract_empty_token(self):
from starlette.requests import Request
bearer = BearerTokenAuth(simple_validator)
scope = {
"type": "http",
"method": "GET",
"path": "/",
"headers": [(b"authorization", b"Bearer ")],
}
request = Request(scope)
assert await bearer.extract(request) is None
@pytest.mark.anyio
async def test_extract_no_prefix(self):
from starlette.requests import Request
bearer = BearerTokenAuth(simple_validator)
scope = {
"type": "http",
"method": "GET",
"path": "/",
"headers": [(b"authorization", b"Bearer mytoken")],
}
request = Request(scope)
assert await bearer.extract(request) == "mytoken"
@pytest.mark.anyio
async def test_extract_prefix_match(self):
from starlette.requests import Request
bearer = BearerTokenAuth(simple_validator, prefix="user_")
scope = {
"type": "http",
"method": "GET",
"path": "/",
"headers": [(b"authorization", b"Bearer user_abc")],
}
request = Request(scope)
assert await bearer.extract(request) == "user_abc"
@pytest.mark.anyio
async def test_extract_prefix_no_match(self):
from starlette.requests import Request
bearer = BearerTokenAuth(simple_validator, prefix="user_")
scope = {
"type": "http",
"method": "GET",
"path": "/",
"headers": [(b"authorization", b"Bearer org_abc")],
}
request = Request(scope)
assert await bearer.extract(request) is None
# --- generate_token() ---
def test_generate_token_no_prefix(self):
bearer = BearerTokenAuth(simple_validator)
token = bearer.generate_token()
assert isinstance(token, str)
assert len(token) > 0
def test_generate_token_with_prefix(self):
bearer = BearerTokenAuth(simple_validator, prefix="user_")
token = bearer.generate_token()
assert token.startswith("user_")
def test_generate_token_uniqueness(self):
bearer = BearerTokenAuth(simple_validator)
assert bearer.generate_token() != bearer.generate_token()
def test_generate_token_is_valid_credential(self):
"""A generated token (with prefix) is accepted by the same auth source."""
stored: list[str] = []
async def storing_validator(credential: str) -> dict:
stored.append(credential)
return {"token": credential}
bearer = BearerTokenAuth(storing_validator, prefix="user_")
token = bearer.generate_token()
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(bearer)):
return user
client = TestClient(_app(setup))
response = client.get("/me", headers={"Authorization": f"Bearer {token}"})
assert response.status_code == 200
assert stored == [token]
class TestCookieAuth:
def test_valid_cookie_returns_identity(self):
cookie_auth = CookieAuth("session", cookie_validator)
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(cookie_auth)):
return user
client = TestClient(_app(setup))
response = client.get("/me", cookies={"session": VALID_COOKIE})
assert response.status_code == 200
assert response.json() == {"session": VALID_COOKIE}
def test_missing_cookie_returns_401(self):
cookie_auth = CookieAuth("session", cookie_validator)
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(cookie_auth)):
return user
client = TestClient(_app(setup))
response = client.get("/me")
assert response.status_code == 401
def test_invalid_cookie_returns_401(self):
cookie_auth = CookieAuth("session", cookie_validator)
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(cookie_auth)):
return user
client = TestClient(_app(setup))
response = client.get("/me", cookies={"session": "wrong"})
assert response.status_code == 401
def test_kwargs_forwarded_to_validator(self):
async def session_validator(value: str, *, scope: str) -> dict:
if value != VALID_COOKIE:
raise UnauthorizedError()
return {"session": value, "scope": scope}
cookie_auth = CookieAuth("session", session_validator, scope="read")
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(cookie_auth)):
return user
client = TestClient(_app(setup))
response = client.get("/me", cookies={"session": VALID_COOKIE})
assert response.status_code == 200
assert response.json() == {"session": VALID_COOKIE, "scope": "read"}
@pytest.mark.anyio
async def test_extract_no_cookie(self):
from starlette.requests import Request
auth = CookieAuth("session", cookie_validator)
scope = {"type": "http", "method": "GET", "path": "/", "headers": []}
request = Request(scope)
assert await auth.extract(request) is None
@pytest.mark.anyio
async def test_extract_cookie_present(self):
from starlette.requests import Request
auth = CookieAuth("session", cookie_validator)
scope = {
"type": "http",
"method": "GET",
"path": "/",
"headers": [(b"cookie", b"session=abc")],
}
request = Request(scope)
assert await auth.extract(request) == "abc"
class TestAPIKeyHeaderAuth:
def test_valid_key_returns_identity(self):
auth = APIKeyHeaderAuth("X-API-Key", simple_validator)
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(auth)):
return user
client = TestClient(_app(setup))
response = client.get("/me", headers={"X-API-Key": VALID_TOKEN})
assert response.status_code == 200
assert response.json() == {"user": "alice"}
def test_missing_header_returns_401(self):
auth = APIKeyHeaderAuth("X-API-Key", simple_validator)
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(auth)):
return user
client = TestClient(_app(setup))
response = client.get("/me")
assert response.status_code == 401
def test_invalid_key_returns_401(self):
auth = APIKeyHeaderAuth("X-API-Key", simple_validator)
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(auth)):
return user
client = TestClient(_app(setup))
response = client.get("/me", headers={"X-API-Key": "wrong"})
assert response.status_code == 401
def test_kwargs_forwarded_to_validator(self):
auth = APIKeyHeaderAuth("X-API-Key", role_validator, role="admin")
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(auth)):
return user
client = TestClient(_app(setup))
response = client.get("/me", headers={"X-API-Key": VALID_TOKEN})
assert response.status_code == 200
assert response.json() == {"user": "alice", "role": "admin"}
def test_require_forwards_kwargs(self):
auth = APIKeyHeaderAuth("X-API-Key", role_validator)
def setup(app: FastAPI):
@app.get("/admin")
async def admin(user=Security(auth.require(role="admin"))):
return user
client = TestClient(_app(setup))
response = client.get("/admin", headers={"X-API-Key": VALID_TOKEN})
assert response.status_code == 200
assert response.json() == {"user": "alice", "role": "admin"}
def test_require_preserves_name(self):
auth = APIKeyHeaderAuth("X-API-Key", simple_validator)
derived = auth.require(role="admin")
assert derived._name == "X-API-Key"
def test_require_does_not_mutate_original(self):
auth = APIKeyHeaderAuth("X-API-Key", role_validator, role="user")
auth.require(role="admin")
assert auth._kwargs == {"role": "user"}
def test_in_multi_auth(self):
"""APIKeyHeaderAuth.authenticate() is exercised inside MultiAuth."""
bearer = BearerTokenAuth(simple_validator)
api_key = APIKeyHeaderAuth("X-API-Key", simple_validator)
multi = MultiAuth(bearer, api_key)
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(multi)):
return user
client = TestClient(_app(setup))
# No bearer → falls through to API key header
response = client.get("/me", headers={"X-API-Key": VALID_TOKEN})
assert response.status_code == 200
assert response.json() == {"user": "alice"}
def test_is_auth_source(self):
auth = APIKeyHeaderAuth("X-API-Key", simple_validator)
assert isinstance(auth, AuthSource)
@pytest.mark.anyio
async def test_extract_no_header(self):
from starlette.requests import Request
auth = APIKeyHeaderAuth("X-API-Key", simple_validator)
scope = {"type": "http", "method": "GET", "path": "/", "headers": []}
request = Request(scope)
assert await auth.extract(request) is None
@pytest.mark.anyio
async def test_extract_empty_header(self):
from starlette.requests import Request
auth = APIKeyHeaderAuth("X-API-Key", simple_validator)
scope = {
"type": "http",
"method": "GET",
"path": "/",
"headers": [(b"x-api-key", b"")],
}
request = Request(scope)
assert await auth.extract(request) is None
@pytest.mark.anyio
async def test_extract_key_present(self):
from starlette.requests import Request
auth = APIKeyHeaderAuth("X-API-Key", simple_validator)
scope = {
"type": "http",
"method": "GET",
"path": "/",
"headers": [(b"x-api-key", b"mykey")],
}
request = Request(scope)
assert await auth.extract(request) == "mykey"
class TestMultiAuth:
def test_first_source_matches(self):
bearer = BearerTokenAuth(simple_validator)
cookie = CookieAuth("session", cookie_validator)
multi = MultiAuth(bearer, cookie)
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(multi)):
return user
client = TestClient(_app(setup))
response = client.get("/me", headers={"Authorization": f"Bearer {VALID_TOKEN}"})
assert response.status_code == 200
assert response.json() == {"user": "alice"}
def test_second_source_matches_when_first_absent(self):
bearer = BearerTokenAuth(simple_validator)
cookie = CookieAuth("session", cookie_validator)
multi = MultiAuth(bearer, cookie)
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(multi)):
return user
client = TestClient(_app(setup))
# No Authorization header — falls through to cookie
response = client.get("/me", cookies={"session": VALID_COOKIE})
assert response.status_code == 200
assert response.json() == {"session": VALID_COOKIE}
def test_no_source_matches_returns_401(self):
bearer = BearerTokenAuth(simple_validator)
cookie = CookieAuth("session", cookie_validator)
multi = MultiAuth(bearer, cookie)
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(multi)):
return user
client = TestClient(_app(setup))
response = client.get("/me")
assert response.status_code == 401
def test_invalid_credential_does_not_fallthrough(self):
"""If a credential is found but invalid, the next source is NOT tried."""
second_called: list[bool] = []
async def tracking_validator(credential: str) -> dict:
second_called.append(True)
return {"from": "second"}
bearer = BearerTokenAuth(simple_validator) # raises on wrong token
cookie = CookieAuth("session", tracking_validator)
multi = MultiAuth(bearer, cookie)
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(multi)):
return user
client = TestClient(_app(setup))
# Bearer credential present but wrong — should NOT try cookie
response = client.get(
"/me",
headers={"Authorization": "Bearer wrong"},
cookies={"session": VALID_COOKIE},
)
assert response.status_code == 401
assert second_called == [] # cookie validator was never called
def test_prefix_routes_to_correct_source(self):
"""Prefix-based dispatch: only the matching source's validator is called."""
user_calls: list[str] = []
org_calls: list[str] = []
async def user_validator(credential: str) -> dict:
user_calls.append(credential)
return {"type": "user", "id": credential}
async def org_validator(credential: str) -> dict:
org_calls.append(credential)
return {"type": "org", "id": credential}
user_bearer = BearerTokenAuth(user_validator, prefix="user_")
org_bearer = BearerTokenAuth(org_validator, prefix="org_")
multi = MultiAuth(user_bearer, org_bearer)
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(multi)):
return user
client = TestClient(_app(setup))
response = client.get("/me", headers={"Authorization": "Bearer user_alice"})
assert response.status_code == 200
assert response.json() == {"type": "user", "id": "user_alice"}
assert user_calls == ["user_alice"]
assert org_calls == []
user_calls.clear()
response = client.get("/me", headers={"Authorization": "Bearer org_acme"})
assert response.status_code == 200
assert response.json() == {"type": "org", "id": "org_acme"}
assert user_calls == []
assert org_calls == ["org_acme"]
def test_require_returns_new_multi_auth(self):
from fastapi_toolsets.security.sources import MultiAuth as MultiAuthClass
bearer = BearerTokenAuth(role_validator)
multi = MultiAuth(bearer)
derived = multi.require(role="admin")
assert isinstance(derived, MultiAuthClass)
assert derived is not multi
def test_require_forwards_kwargs_to_sources(self):
"""multi.require() propagates to all sources that support it."""
bearer = BearerTokenAuth(role_validator)
multi = MultiAuth(bearer)
def setup(app: FastAPI):
@app.get("/admin")
async def admin(user=Security(multi.require(role="admin"))):
return user
client = TestClient(_app(setup))
response = client.get(
"/admin", headers={"Authorization": f"Bearer {VALID_TOKEN}"}
)
assert response.status_code == 200
assert response.json() == {"user": "alice", "role": "admin"}
def test_require_skips_sources_without_require(self):
"""Sources without require() are passed through unchanged."""
header_auth = _HeaderAuth(secret="s3cr3t")
multi = MultiAuth(header_auth)
derived = multi.require(role="admin")
assert derived._sources[0] is header_auth
def test_require_does_not_mutate_original(self):
bearer = BearerTokenAuth(role_validator, role="user")
multi = MultiAuth(bearer)
multi.require(role="admin")
assert bearer._kwargs == {"role": "user"}
def test_require_mixed_sources(self):
"""require() applies to sources with require(), skips those without."""
from typing import cast
bearer = BearerTokenAuth(role_validator)
header_auth = _HeaderAuth(secret="s3cr3t")
multi = MultiAuth(bearer, header_auth)
derived = multi.require(role="admin")
# bearer got require() applied, header_auth passed through
assert cast(BearerTokenAuth, derived._sources[0])._kwargs == {"role": "admin"}
assert derived._sources[1] is header_auth
class TestRequire:
def test_bearer_require_forwards_kwargs(self):
"""require() creates a new instance that passes merged kwargs to validator."""
bearer = BearerTokenAuth(role_validator)
def setup(app: FastAPI):
@app.get("/admin")
async def admin(user=Security(bearer.require(role="admin"))):
return user
client = TestClient(_app(setup))
response = client.get(
"/admin", headers={"Authorization": f"Bearer {VALID_TOKEN}"}
)
assert response.status_code == 200
assert response.json() == {"user": "alice", "role": "admin"}
def test_bearer_require_overrides_existing_kwarg(self):
"""require() kwargs override kwargs set at instantiation."""
bearer = BearerTokenAuth(role_validator, role="user")
def setup(app: FastAPI):
@app.get("/admin")
async def admin(user=Security(bearer.require(role="admin"))):
return user
client = TestClient(_app(setup))
response = client.get(
"/admin", headers={"Authorization": f"Bearer {VALID_TOKEN}"}
)
assert response.status_code == 200
assert response.json()["role"] == "admin"
def test_bearer_require_preserves_prefix(self):
"""require() keeps the prefix of the original instance."""
bearer = BearerTokenAuth(role_validator, prefix="user_")
derived = bearer.require(role="admin")
assert derived._prefix == "user_"
def test_bearer_require_does_not_mutate_original(self):
"""require() returns a new instance — original kwargs are unchanged."""
bearer = BearerTokenAuth(role_validator, role="user")
bearer.require(role="admin")
assert bearer._kwargs == {"role": "user"}
def test_cookie_require_forwards_kwargs(self):
async def scoped_validator(value: str, *, scope: str) -> dict:
if value != VALID_COOKIE:
raise UnauthorizedError()
return {"session": value, "scope": scope}
cookie = CookieAuth("session", scoped_validator)
def setup(app: FastAPI):
@app.get("/admin")
async def admin(user=Security(cookie.require(scope="admin"))):
return user
client = TestClient(_app(setup))
response = client.get("/admin", cookies={"session": VALID_COOKIE})
assert response.status_code == 200
assert response.json() == {"session": VALID_COOKIE, "scope": "admin"}
def test_cookie_require_preserves_name(self):
cookie = CookieAuth("session", cookie_validator)
derived = cookie.require(scope="admin")
assert derived._name == "session"
def test_bearer_require_in_multi_auth(self):
"""require() instances work seamlessly inside MultiAuth."""
PREFIXED_TOKEN = f"user_{VALID_TOKEN}"
async def prefixed_role_validator(credential: str, *, role: str) -> dict:
if credential != PREFIXED_TOKEN:
raise UnauthorizedError()
return {"user": "alice", "role": role}
bearer = BearerTokenAuth(prefixed_role_validator, prefix="user_")
multi = MultiAuth(bearer.require(role="admin"))
def setup(app: FastAPI):
@app.get("/admin")
async def admin(user=Security(multi)):
return user
client = TestClient(_app(setup))
response = client.get(
"/admin", headers={"Authorization": f"Bearer {PREFIXED_TOKEN}"}
)
assert response.status_code == 200
assert response.json() == {"user": "alice", "role": "admin"}
class TestSyncValidators:
"""Sync (non-async) validators — covers the sync path in _call_validator."""
def test_bearer_sync_validator(self):
def sync_validator(credential: str) -> dict:
if credential != VALID_TOKEN:
raise UnauthorizedError()
return {"user": "alice"}
bearer = BearerTokenAuth(sync_validator)
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(bearer)):
return user
client = TestClient(_app(setup))
response = client.get("/me", headers={"Authorization": f"Bearer {VALID_TOKEN}"})
assert response.status_code == 200
assert response.json() == {"user": "alice"}
def test_sync_validator_via_authenticate(self):
"""authenticate() with sync validator (MultiAuth path)."""
def sync_validator(credential: str) -> dict:
if credential != VALID_TOKEN:
raise UnauthorizedError()
return {"user": "alice"}
bearer = BearerTokenAuth(sync_validator)
multi = MultiAuth(bearer)
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(multi)):
return user
client = TestClient(_app(setup))
response = client.get("/me", headers={"Authorization": f"Bearer {VALID_TOKEN}"})
assert response.status_code == 200
assert response.json() == {"user": "alice"}
class TestCookieAuthSigned:
"""CookieAuth with HMAC-SHA256 signed cookies (secret_key path)."""
SECRET = "test-hmac-secret"
def test_valid_signed_cookie_via_set_cookie(self):
"""set_cookie signs the value; the signed cookie is verified on read."""
from fastapi import Response
auth = CookieAuth("session", cookie_validator, secret_key=self.SECRET)
def setup(app: FastAPI):
@app.get("/login")
async def login(response: Response):
auth.set_cookie(response, VALID_COOKIE)
return {"ok": True}
@app.get("/me")
async def me(user=Security(auth)):
return user
with TestClient(_app(setup)) as client:
client.get("/login")
response = client.get("/me")
assert response.status_code == 200
assert response.json() == {"session": VALID_COOKIE}
def test_tampered_signature_returns_401(self):
"""A cookie whose HMAC signature has been modified is rejected."""
import base64 as _b64
import json as _json
import time as _time
auth = CookieAuth("session", cookie_validator, secret_key=self.SECRET)
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(auth)):
return user
client = TestClient(_app(setup))
data = _b64.urlsafe_b64encode(
_json.dumps({"v": VALID_COOKIE, "exp": int(_time.time()) + 9999}).encode()
).decode()
response = client.get("/me", cookies={"session": f"{data}.invalidsig"})
assert response.status_code == 401
def test_expired_signed_cookie_returns_401(self):
"""A signed cookie past its expiry timestamp is rejected."""
import base64 as _b64
import hashlib as _hashlib
import hmac as _hmac
import json as _json
import time as _time
auth = CookieAuth("session", cookie_validator, secret_key=self.SECRET)
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(auth)):
return user
client = TestClient(_app(setup))
data = _b64.urlsafe_b64encode(
_json.dumps({"v": VALID_COOKIE, "exp": int(_time.time()) - 1}).encode()
).decode()
sig = _hmac.new(
self.SECRET.encode(), data.encode(), _hashlib.sha256
).hexdigest()
response = client.get("/me", cookies={"session": f"{data}.{sig}"})
assert response.status_code == 401
def test_invalid_json_payload_returns_401(self):
"""A signed cookie whose payload is not valid JSON is rejected."""
import base64 as _b64
import hashlib as _hashlib
import hmac as _hmac
auth = CookieAuth("session", cookie_validator, secret_key=self.SECRET)
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(auth)):
return user
client = TestClient(_app(setup))
data = _b64.urlsafe_b64encode(b"not-valid-json").decode()
sig = _hmac.new(
self.SECRET.encode(), data.encode(), _hashlib.sha256
).hexdigest()
response = client.get("/me", cookies={"session": f"{data}.{sig}"})
assert response.status_code == 401
def test_malformed_cookie_no_dot_returns_401(self):
"""A signed cookie without the dot separator is rejected."""
auth = CookieAuth("session", cookie_validator, secret_key=self.SECRET)
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(auth)):
return user
client = TestClient(_app(setup))
response = client.get("/me", cookies={"session": "nodothere"})
assert response.status_code == 401
def test_set_cookie_without_secret(self):
"""set_cookie without secret_key writes the raw value."""
from starlette.responses import Response as StarletteResponse
auth = CookieAuth("session", cookie_validator)
response = StarletteResponse()
auth.set_cookie(response, "rawvalue")
assert "session=rawvalue" in response.headers["set-cookie"]
def test_delete_cookie(self):
"""delete_cookie produces a Set-Cookie header that clears the session."""
from starlette.responses import Response as StarletteResponse
auth = CookieAuth("session", cookie_validator)
response = StarletteResponse()
auth.delete_cookie(response)
assert "session" in response.headers["set-cookie"]
# Minimal concrete subclass used only in tests below.
class _HeaderAuth(AuthSource):
"""Reads a custom X-Token header — no FastAPI security scheme."""
def __init__(self, secret: str) -> None:
super().__init__()
self._secret = secret
async def extract(self, request) -> str | None:
return request.headers.get("X-Token") or None
async def authenticate(self, credential: str) -> dict:
if credential != self._secret:
raise UnauthorizedError()
return {"token": credential}
class TestAuthSource:
def test_cannot_instantiate_abstract_class(self):
with pytest.raises(TypeError):
AuthSource()
def test_builtin_classes_are_auth_sources(self):
bearer = BearerTokenAuth(simple_validator)
cookie = CookieAuth("session", cookie_validator)
api_key = APIKeyHeaderAuth("X-API-Key", simple_validator)
assert isinstance(bearer, AuthSource)
assert isinstance(cookie, AuthSource)
assert isinstance(api_key, AuthSource)
def test_custom_source_standalone_valid(self):
"""Default __call__ wires extract + authenticate via Request injection."""
auth = _HeaderAuth(secret="s3cr3t")
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(auth)):
return user
client = TestClient(_app(setup))
response = client.get("/me", headers={"X-Token": "s3cr3t"})
assert response.status_code == 200
assert response.json() == {"token": "s3cr3t"}
def test_custom_source_standalone_missing_credential(self):
auth = _HeaderAuth(secret="s3cr3t")
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(auth)):
return user
client = TestClient(_app(setup))
response = client.get("/me") # no X-Token header
assert response.status_code == 401
def test_custom_source_standalone_invalid_credential(self):
auth = _HeaderAuth(secret="s3cr3t")
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(auth)):
return user
client = TestClient(_app(setup))
response = client.get("/me", headers={"X-Token": "wrong"})
assert response.status_code == 401
def test_custom_source_in_multi_auth(self):
"""Custom AuthSource works transparently inside MultiAuth."""
header_auth = _HeaderAuth(secret="s3cr3t")
bearer = BearerTokenAuth(simple_validator)
multi = MultiAuth(bearer, header_auth)
def setup(app: FastAPI):
@app.get("/me")
async def me(user=Security(multi)):
return user
client = TestClient(_app(setup))
# Bearer matches first
response = client.get("/me", headers={"Authorization": f"Bearer {VALID_TOKEN}"})
assert response.status_code == 200
assert response.json() == {"user": "alice"}
# No bearer → falls through to custom header source
response = client.get("/me", headers={"X-Token": "s3cr3t"})
assert response.status_code == 200
assert response.json() == {"token": "s3cr3t"}