mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 23:02:29 +02:00
feat: add security module
This commit is contained in:
267
docs/module/security.md
Normal file
267
docs/module/security.md
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
# Security
|
||||||
|
|
||||||
|
Composable authentication helpers for FastAPI that use `Security()` for OpenAPI documentation and accept user-provided validator functions with full type flexibility.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `security` module provides four auth source classes and a `MultiAuth` factory. Each class wraps a FastAPI security scheme for OpenAPI and accepts a validator function called as:
|
||||||
|
|
||||||
|
```python
|
||||||
|
await validator(credential, **kwargs)
|
||||||
|
```
|
||||||
|
|
||||||
|
where `kwargs` are the extra keyword arguments provided at instantiation (roles, permissions, enums, etc.). The validator returns the authenticated identity (e.g. a `User` model) which becomes the route dependency value.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi import Security
|
||||||
|
from fastapi_toolsets.security import BearerTokenAuth
|
||||||
|
|
||||||
|
async def verify_token(token: str, *, role: str) -> User:
|
||||||
|
user = await db.get_by_token(token)
|
||||||
|
if not user or user.role != role:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
return user
|
||||||
|
|
||||||
|
bearer_admin = BearerTokenAuth(verify_token, role="admin")
|
||||||
|
|
||||||
|
@app.get("/admin")
|
||||||
|
async def admin_route(user: User = Security(bearer_admin)):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
## Auth sources
|
||||||
|
|
||||||
|
### [`BearerTokenAuth`](../reference/security.md#fastapi_toolsets.security.BearerTokenAuth)
|
||||||
|
|
||||||
|
Reads the `Authorization: Bearer <token>` header. Wraps `HTTPBearer` for OpenAPI.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.security import BearerTokenAuth
|
||||||
|
|
||||||
|
bearer = BearerTokenAuth(validator=verify_token)
|
||||||
|
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user: User = Security(bearer)):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Token prefix
|
||||||
|
|
||||||
|
The optional `prefix` parameter restricts a `BearerTokenAuth` instance to tokens
|
||||||
|
that start with a given string. The prefix is **kept** in the value passed to the
|
||||||
|
validator — store and compare tokens with their prefix included.
|
||||||
|
|
||||||
|
This lets you deploy multiple `BearerTokenAuth` instances in the same application
|
||||||
|
and disambiguate them efficiently in `MultiAuth`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
user_bearer = BearerTokenAuth(verify_user, prefix="user_") # matches "Bearer user_..."
|
||||||
|
org_bearer = BearerTokenAuth(verify_org, prefix="org_") # matches "Bearer org_..."
|
||||||
|
```
|
||||||
|
|
||||||
|
Use [`generate_token()`](#token-generation) to create correctly-prefixed tokens.
|
||||||
|
|
||||||
|
#### Token generation
|
||||||
|
|
||||||
|
`BearerTokenAuth.generate_token()` produces a secure random token ready to store
|
||||||
|
in your database and return to the client. If a prefix is configured it is
|
||||||
|
prepended automatically:
|
||||||
|
|
||||||
|
```python
|
||||||
|
bearer = BearerTokenAuth(verify_token, prefix="user_")
|
||||||
|
|
||||||
|
token = bearer.generate_token() # e.g. "user_Xk3mN..."
|
||||||
|
await db.store_token(user_id, token)
|
||||||
|
return {"access_token": token, "token_type": "bearer"}
|
||||||
|
```
|
||||||
|
|
||||||
|
The client sends `Authorization: Bearer user_Xk3mN...` and the validator receives
|
||||||
|
the full token (prefix included) to compare against the stored value.
|
||||||
|
|
||||||
|
### [`CookieAuth`](../reference/security.md#fastapi_toolsets.security.CookieAuth)
|
||||||
|
|
||||||
|
Reads a named cookie. Wraps `APIKeyCookie` for OpenAPI.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.security import CookieAuth
|
||||||
|
|
||||||
|
cookie_auth = CookieAuth("session", validator=verify_session)
|
||||||
|
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user: User = Security(cookie_auth)):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
### [`OAuth2Auth`](../reference/security.md#fastapi_toolsets.security.OAuth2Auth)
|
||||||
|
|
||||||
|
Reads the `Authorization: Bearer <token>` header and registers the token endpoint
|
||||||
|
in OpenAPI via `OAuth2PasswordBearer`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.security import OAuth2Auth
|
||||||
|
|
||||||
|
oauth2_auth = OAuth2Auth(token_url="/token", validator=verify_token)
|
||||||
|
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user: User = Security(oauth2_auth)):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
### [`OpenIDAuth`](../reference/security.md#fastapi_toolsets.security.OpenIDAuth)
|
||||||
|
|
||||||
|
Reads the `Authorization: Bearer <token>` header and registers the OpenID Connect
|
||||||
|
discovery URL in OpenAPI via `OpenIdConnect`. Token validation is fully delegated
|
||||||
|
to your validator — use any OIDC / JWT library (`authlib`, `python-jose`, `PyJWT`).
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.security import OpenIDAuth
|
||||||
|
|
||||||
|
async def verify_google_token(token: str, *, audience: str) -> User:
|
||||||
|
payload = jwt.decode(token, google_public_keys, algorithms=["RS256"],
|
||||||
|
audience=audience)
|
||||||
|
return User(email=payload["email"], name=payload["name"])
|
||||||
|
|
||||||
|
google_auth = OpenIDAuth(
|
||||||
|
"https://accounts.google.com/.well-known/openid-configuration",
|
||||||
|
verify_google_token,
|
||||||
|
audience="my-client-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user: User = Security(google_auth)):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
The discovery URL is used **only for OpenAPI documentation** — no requests are made
|
||||||
|
to it by this class. You are responsible for fetching and caching the provider's
|
||||||
|
public keys in your validator.
|
||||||
|
|
||||||
|
Multiple providers work naturally with `MultiAuth`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
multi = MultiAuth(google_auth, github_auth)
|
||||||
|
|
||||||
|
@app.get("/data")
|
||||||
|
async def data(user: User = Security(multi)):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
## Typed validator kwargs
|
||||||
|
|
||||||
|
All auth classes forward extra instantiation keyword arguments to the validator.
|
||||||
|
Arguments can be any type — enums, strings, integers, etc. The validator returns
|
||||||
|
the authenticated identity, which FastAPI injects directly into the route handler.
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def verify_token(token: str, *, role: Role, permission: str) -> User:
|
||||||
|
user = await decode_token(token)
|
||||||
|
if user.role != role or permission not in user.permissions:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
return user
|
||||||
|
|
||||||
|
bearer = BearerTokenAuth(verify_token, role=Role.ADMIN, permission="billing:read")
|
||||||
|
```
|
||||||
|
|
||||||
|
Each auth instance is self-contained — create a separate instance per distinct
|
||||||
|
requirement instead of passing requirements through `Security(scopes=[...])`.
|
||||||
|
|
||||||
|
### Using `.require()` inline
|
||||||
|
|
||||||
|
If declaring a new top-level variable per role feels verbose, use `.require()` to
|
||||||
|
create a configured clone directly in the route decorator. The original instance
|
||||||
|
is not mutated:
|
||||||
|
|
||||||
|
```python
|
||||||
|
bearer = BearerTokenAuth(verify_token)
|
||||||
|
|
||||||
|
@app.get("/admin/stats")
|
||||||
|
async def admin_stats(user: User = Security(bearer.require(role=Role.ADMIN))):
|
||||||
|
return {"message": f"Hello admin {user.name}"}
|
||||||
|
|
||||||
|
@app.get("/profile")
|
||||||
|
async def profile(user: User = Security(bearer.require(role=Role.USER))):
|
||||||
|
return {"id": user.id, "name": user.name}
|
||||||
|
```
|
||||||
|
|
||||||
|
`.require()` kwargs are merged over existing ones — new values win on conflict.
|
||||||
|
The `prefix` (for `BearerTokenAuth`) and cookie name (for `CookieAuth`) are
|
||||||
|
always preserved.
|
||||||
|
|
||||||
|
`.require()` instances work transparently inside `MultiAuth`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
multi = MultiAuth(
|
||||||
|
user_bearer.require(role=Role.USER),
|
||||||
|
org_bearer.require(role=Role.ADMIN),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## MultiAuth
|
||||||
|
|
||||||
|
[`MultiAuth`](../reference/security.md#fastapi_toolsets.security.MultiAuth) combines
|
||||||
|
multiple auth sources into a single callable. Sources are tried in order; the
|
||||||
|
first one that finds a credential wins.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.security import MultiAuth
|
||||||
|
|
||||||
|
multi = MultiAuth(user_bearer, org_bearer, cookie_auth)
|
||||||
|
|
||||||
|
@app.get("/data")
|
||||||
|
async def data_route(user = Security(multi)):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using `.require()` on MultiAuth
|
||||||
|
|
||||||
|
`MultiAuth` also supports `.require()`, which propagates the kwargs to every
|
||||||
|
source that implements it. Sources that do not (e.g. custom `AuthSource`
|
||||||
|
subclasses) are passed through unchanged:
|
||||||
|
|
||||||
|
```python
|
||||||
|
multi = MultiAuth(bearer, cookie)
|
||||||
|
|
||||||
|
@app.get("/admin")
|
||||||
|
async def admin(user: User = Security(multi.require(role=Role.ADMIN))):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
This is equivalent to calling `.require()` on each source individually:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# These two are identical
|
||||||
|
multi.require(role=Role.ADMIN)
|
||||||
|
|
||||||
|
MultiAuth(
|
||||||
|
bearer.require(role=Role.ADMIN),
|
||||||
|
cookie.require(role=Role.ADMIN),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Prefix-based dispatch
|
||||||
|
|
||||||
|
Because `extract()` is pure string matching (no I/O), prefix-based source
|
||||||
|
selection is essentially free. Only the matching source's validator (which may
|
||||||
|
involve DB or network I/O) is ever called:
|
||||||
|
|
||||||
|
```python
|
||||||
|
user_bearer = BearerTokenAuth(verify_user, prefix="user_")
|
||||||
|
org_bearer = BearerTokenAuth(verify_org, prefix="org_")
|
||||||
|
|
||||||
|
multi = MultiAuth(user_bearer, org_bearer)
|
||||||
|
|
||||||
|
# "Bearer user_alice" → only verify_user runs, receives "user_alice"
|
||||||
|
# "Bearer org_acme" → only verify_org runs, receives "org_acme"
|
||||||
|
```
|
||||||
|
|
||||||
|
Tokens are stored and compared **with their prefix** — use `generate_token()` on
|
||||||
|
each source to issue correctly-prefixed tokens:
|
||||||
|
|
||||||
|
```python
|
||||||
|
user_token = user_bearer.generate_token() # "user_..."
|
||||||
|
org_token = org_bearer.generate_token() # "org_..."
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/security.md)
|
||||||
28
docs/reference/security.md
Normal file
28
docs/reference/security.md
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# `security`
|
||||||
|
|
||||||
|
Here's the reference for the authentication helpers provided by the `security` module.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.security`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.security import (
|
||||||
|
AuthSource,
|
||||||
|
BearerTokenAuth,
|
||||||
|
CookieAuth,
|
||||||
|
OAuth2Auth,
|
||||||
|
OpenIDAuth,
|
||||||
|
MultiAuth,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.security.AuthSource
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.security.BearerTokenAuth
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.security.CookieAuth
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.security.OAuth2Auth
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.security.OpenIDAuth
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.security.MultiAuth
|
||||||
15
src/fastapi_toolsets/security/__init__.py
Normal file
15
src/fastapi_toolsets/security/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""Authentication helpers for FastAPI using Security()."""
|
||||||
|
|
||||||
|
from .abc import AuthSource
|
||||||
|
from .oauth import decode_oauth_state, encode_oauth_state
|
||||||
|
from .sources import APIKeyHeaderAuth, BearerTokenAuth, CookieAuth, MultiAuth
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"APIKeyHeaderAuth",
|
||||||
|
"AuthSource",
|
||||||
|
"BearerTokenAuth",
|
||||||
|
"CookieAuth",
|
||||||
|
"MultiAuth",
|
||||||
|
"decode_oauth_state",
|
||||||
|
"encode_oauth_state",
|
||||||
|
]
|
||||||
51
src/fastapi_toolsets/security/abc.py
Normal file
51
src/fastapi_toolsets/security/abc.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
"""Abstract base class for authentication sources."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from fastapi.security import SecurityScopes
|
||||||
|
|
||||||
|
from fastapi_toolsets.exceptions import UnauthorizedError
|
||||||
|
|
||||||
|
|
||||||
|
async def _call_validator(
|
||||||
|
validator: Callable[..., Any], *args: Any, **kwargs: Any
|
||||||
|
) -> Any:
|
||||||
|
"""Call *validator* with *args* and *kwargs*, awaiting it if it is a coroutine function."""
|
||||||
|
if inspect.iscoroutinefunction(validator):
|
||||||
|
return await validator(*args, **kwargs)
|
||||||
|
return validator(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthSource(ABC):
|
||||||
|
"""Abstract base class for authentication sources."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Set up the default FastAPI dependency signature."""
|
||||||
|
source = self
|
||||||
|
|
||||||
|
async def _call(
|
||||||
|
request: Request,
|
||||||
|
security_scopes: SecurityScopes, # noqa: ARG001
|
||||||
|
) -> Any:
|
||||||
|
credential = await source.extract(request)
|
||||||
|
if credential is None:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
return await source.authenticate(credential)
|
||||||
|
|
||||||
|
self._call_fn: Callable[..., Any] = _call
|
||||||
|
self.__signature__ = inspect.signature(_call)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def extract(self, request: Request) -> str | None:
|
||||||
|
"""Extract the raw credential from the request without validating."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def authenticate(self, credential: str) -> Any:
|
||||||
|
"""Validate a credential and return the authenticated identity."""
|
||||||
|
|
||||||
|
async def __call__(self, **kwargs: Any) -> Any:
|
||||||
|
"""FastAPI dependency dispatch."""
|
||||||
|
return await self._call_fn(**kwargs)
|
||||||
24
src/fastapi_toolsets/security/oauth.py
Normal file
24
src/fastapi_toolsets/security/oauth.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
"""OAuth 2.0 / OIDC helper utilities."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
|
||||||
|
|
||||||
|
def encode_oauth_state(url: str) -> str:
|
||||||
|
"""Base64url-encode a URL to embed as an OAuth ``state`` parameter."""
|
||||||
|
return base64.urlsafe_b64encode(url.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def decode_oauth_state(state: str | None, *, fallback: str) -> str:
|
||||||
|
"""Decode a base64url OAuth ``state`` parameter.
|
||||||
|
|
||||||
|
Handles missing padding (some providers strip ``=``).
|
||||||
|
Returns *fallback* if *state* is absent, the literal string ``"null"``,
|
||||||
|
or cannot be decoded.
|
||||||
|
"""
|
||||||
|
if not state or state == "null":
|
||||||
|
return fallback
|
||||||
|
try:
|
||||||
|
padded = state + "=" * (4 - len(state) % 4)
|
||||||
|
return base64.urlsafe_b64decode(padded).decode()
|
||||||
|
except Exception:
|
||||||
|
return fallback
|
||||||
8
src/fastapi_toolsets/security/sources/__init__.py
Normal file
8
src/fastapi_toolsets/security/sources/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
"""Built-in authentication source implementations."""
|
||||||
|
|
||||||
|
from .header import APIKeyHeaderAuth
|
||||||
|
from .bearer import BearerTokenAuth
|
||||||
|
from .cookie import CookieAuth
|
||||||
|
from .multi import MultiAuth
|
||||||
|
|
||||||
|
__all__ = ["APIKeyHeaderAuth", "BearerTokenAuth", "CookieAuth", "MultiAuth"]
|
||||||
122
src/fastapi_toolsets/security/sources/bearer.py
Normal file
122
src/fastapi_toolsets/security/sources/bearer.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
"""Bearer token authentication source."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import secrets
|
||||||
|
from typing import Annotated, Any, Callable
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer, SecurityScopes
|
||||||
|
|
||||||
|
from fastapi_toolsets.exceptions import UnauthorizedError
|
||||||
|
|
||||||
|
from ..abc import AuthSource, _call_validator
|
||||||
|
|
||||||
|
|
||||||
|
class BearerTokenAuth(AuthSource):
|
||||||
|
"""Bearer token authentication source.
|
||||||
|
|
||||||
|
Wraps :class:`fastapi.security.HTTPBearer` for OpenAPI documentation.
|
||||||
|
The validator is called as ``await validator(credential, **kwargs)``
|
||||||
|
where ``kwargs`` are the extra keyword arguments provided at instantiation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
validator: Sync or async callable that receives the credential and any
|
||||||
|
extra keyword arguments, and returns the authenticated identity
|
||||||
|
(e.g. a ``User`` model). Should raise
|
||||||
|
:class:`~fastapi_toolsets.exceptions.UnauthorizedError` on failure.
|
||||||
|
prefix: Optional token prefix (e.g. ``"user_"``). If set, only tokens
|
||||||
|
whose value starts with this prefix are matched. The prefix is
|
||||||
|
**kept** in the value passed to the validator — store and compare
|
||||||
|
tokens with their prefix included. Use :meth:`generate_token` to
|
||||||
|
create correctly-prefixed tokens. This enables multiple
|
||||||
|
``BearerTokenAuth`` instances in the same app (e.g. ``"user_"``
|
||||||
|
for user tokens, ``"org_"`` for org tokens).
|
||||||
|
**kwargs: Extra keyword arguments forwarded to the validator on every
|
||||||
|
call (e.g. ``role=Role.ADMIN``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
validator: Callable[..., Any],
|
||||||
|
*,
|
||||||
|
prefix: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
self._validator = validator
|
||||||
|
self._prefix = prefix
|
||||||
|
self._kwargs = kwargs
|
||||||
|
self._scheme = HTTPBearer(auto_error=False)
|
||||||
|
|
||||||
|
_scheme = self._scheme
|
||||||
|
_validator = validator
|
||||||
|
_kwargs = kwargs
|
||||||
|
_prefix = prefix
|
||||||
|
|
||||||
|
async def _call(
|
||||||
|
security_scopes: SecurityScopes, # noqa: ARG001
|
||||||
|
credentials: Annotated[
|
||||||
|
HTTPAuthorizationCredentials | None, Depends(_scheme)
|
||||||
|
] = None,
|
||||||
|
) -> Any:
|
||||||
|
if credentials is None:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
token = credentials.credentials
|
||||||
|
if _prefix is not None and not token.startswith(_prefix):
|
||||||
|
raise UnauthorizedError()
|
||||||
|
return await _call_validator(_validator, token, **_kwargs)
|
||||||
|
|
||||||
|
self._call_fn = _call
|
||||||
|
self.__signature__ = inspect.signature(_call)
|
||||||
|
|
||||||
|
async def extract(self, request: Any) -> str | None:
|
||||||
|
"""Extract the raw credential from the request without validating.
|
||||||
|
|
||||||
|
Returns ``None`` if no ``Authorization: Bearer`` header is present,
|
||||||
|
the token is empty, or the token does not match the configured prefix.
|
||||||
|
The prefix is included in the returned value.
|
||||||
|
"""
|
||||||
|
auth = request.headers.get("Authorization", "")
|
||||||
|
if not auth.startswith("Bearer "):
|
||||||
|
return None
|
||||||
|
token = auth[7:]
|
||||||
|
if not token:
|
||||||
|
return None
|
||||||
|
if self._prefix is not None and not token.startswith(self._prefix):
|
||||||
|
return None
|
||||||
|
return token
|
||||||
|
|
||||||
|
async def authenticate(self, credential: str) -> Any:
|
||||||
|
"""Validate a credential and return the identity.
|
||||||
|
|
||||||
|
Calls ``await validator(credential, **kwargs)`` where ``kwargs`` are
|
||||||
|
the extra keyword arguments provided at instantiation.
|
||||||
|
"""
|
||||||
|
return await _call_validator(self._validator, credential, **self._kwargs)
|
||||||
|
|
||||||
|
def require(self, **kwargs: Any) -> "BearerTokenAuth":
|
||||||
|
"""Return a new instance with additional (or overriding) validator kwargs."""
|
||||||
|
return BearerTokenAuth(
|
||||||
|
self._validator,
|
||||||
|
prefix=self._prefix,
|
||||||
|
**{**self._kwargs, **kwargs},
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_token(self, nbytes: int = 32) -> str:
|
||||||
|
"""Generate a secure random token for this auth source.
|
||||||
|
|
||||||
|
Returns a URL-safe random token. If a prefix is configured it is
|
||||||
|
prepended — the returned value is what you store in your database
|
||||||
|
and return to the client as-is.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
nbytes: Number of random bytes before base64 encoding. The
|
||||||
|
resulting string is ``ceil(nbytes * 4 / 3)`` characters
|
||||||
|
(43 chars for the default 32 bytes). Defaults to 32.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A ready-to-use token string (e.g. ``"user_Xk3..."``).
|
||||||
|
"""
|
||||||
|
token = secrets.token_urlsafe(nbytes)
|
||||||
|
if self._prefix is not None:
|
||||||
|
return f"{self._prefix}{token}"
|
||||||
|
return token
|
||||||
142
src/fastapi_toolsets/security/sources/cookie.py
Normal file
142
src/fastapi_toolsets/security/sources/cookie.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
"""Cookie-based authentication source."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import Annotated, Any, Callable
|
||||||
|
|
||||||
|
from fastapi import Depends, Request, Response
|
||||||
|
from fastapi.security import APIKeyCookie, SecurityScopes
|
||||||
|
|
||||||
|
from fastapi_toolsets.exceptions import UnauthorizedError
|
||||||
|
|
||||||
|
from ..abc import AuthSource, _call_validator
|
||||||
|
|
||||||
|
|
||||||
|
class CookieAuth(AuthSource):
|
||||||
|
"""Cookie-based authentication source.
|
||||||
|
|
||||||
|
Wraps :class:`fastapi.security.APIKeyCookie` for OpenAPI documentation.
|
||||||
|
Optionally signs the cookie with HMAC-SHA256 to provide stateless, tamper-
|
||||||
|
proof sessions without any database entry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Cookie name.
|
||||||
|
validator: Sync or async callable that receives the cookie value
|
||||||
|
(plain, after signature verification when ``secret_key`` is set)
|
||||||
|
and any extra keyword arguments, and returns the authenticated
|
||||||
|
identity.
|
||||||
|
secret_key: When provided, the cookie is HMAC-SHA256 signed.
|
||||||
|
:meth:`set_cookie` embeds an expiry and signs the payload;
|
||||||
|
:meth:`extract` verifies the signature and expiry before handing
|
||||||
|
the plain value to the validator. When ``None`` (default), the raw
|
||||||
|
cookie value is passed to the validator as-is.
|
||||||
|
ttl: Cookie lifetime in seconds (default 24 h). Only used when
|
||||||
|
``secret_key`` is set.
|
||||||
|
**kwargs: Extra keyword arguments forwarded to the validator on every
|
||||||
|
call (e.g. ``role=Role.ADMIN``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
validator: Callable[..., Any],
|
||||||
|
*,
|
||||||
|
secret_key: str | None = None,
|
||||||
|
ttl: int = 86400,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
self._name = name
|
||||||
|
self._validator = validator
|
||||||
|
self._secret_key = secret_key
|
||||||
|
self._ttl = ttl
|
||||||
|
self._kwargs = kwargs
|
||||||
|
self._scheme = APIKeyCookie(name=name, auto_error=False)
|
||||||
|
|
||||||
|
_scheme = self._scheme
|
||||||
|
_self = self
|
||||||
|
_kwargs = kwargs
|
||||||
|
|
||||||
|
async def _call(
|
||||||
|
security_scopes: SecurityScopes, # noqa: ARG001
|
||||||
|
value: Annotated[str | None, Depends(_scheme)] = None,
|
||||||
|
) -> Any:
|
||||||
|
if value is None:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
plain = _self._verify(value)
|
||||||
|
return await _call_validator(_self._validator, plain, **_kwargs)
|
||||||
|
|
||||||
|
self._call_fn = _call
|
||||||
|
self.__signature__ = inspect.signature(_call)
|
||||||
|
|
||||||
|
def _hmac(self, data: str) -> str:
|
||||||
|
assert self._secret_key is not None
|
||||||
|
return hmac.new(
|
||||||
|
self._secret_key.encode(), data.encode(), hashlib.sha256
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
|
def _sign(self, value: str) -> str:
|
||||||
|
data = base64.urlsafe_b64encode(
|
||||||
|
json.dumps({"v": value, "exp": int(time.time()) + self._ttl}).encode()
|
||||||
|
).decode()
|
||||||
|
return f"{data}.{self._hmac(data)}"
|
||||||
|
|
||||||
|
def _verify(self, cookie_value: str) -> str:
|
||||||
|
"""Return the plain value, verifying HMAC + expiry when signed."""
|
||||||
|
if not self._secret_key:
|
||||||
|
return cookie_value
|
||||||
|
|
||||||
|
try:
|
||||||
|
data, sig = cookie_value.rsplit(".", 1)
|
||||||
|
except ValueError:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
|
||||||
|
if not hmac.compare_digest(self._hmac(data), sig):
|
||||||
|
raise UnauthorizedError()
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = json.loads(base64.urlsafe_b64decode(data))
|
||||||
|
value: str = payload["v"]
|
||||||
|
exp: int = payload["exp"]
|
||||||
|
except Exception:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
|
||||||
|
if exp < int(time.time()):
|
||||||
|
raise UnauthorizedError()
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
async def extract(self, request: Request) -> str | None:
|
||||||
|
return request.cookies.get(self._name)
|
||||||
|
|
||||||
|
async def authenticate(self, credential: str) -> Any:
|
||||||
|
plain = self._verify(credential)
|
||||||
|
return await _call_validator(self._validator, plain, **self._kwargs)
|
||||||
|
|
||||||
|
def require(self, **kwargs: Any) -> "CookieAuth":
|
||||||
|
"""Return a new instance with additional (or overriding) validator kwargs."""
|
||||||
|
return CookieAuth(
|
||||||
|
self._name,
|
||||||
|
self._validator,
|
||||||
|
secret_key=self._secret_key,
|
||||||
|
ttl=self._ttl,
|
||||||
|
**{**self._kwargs, **kwargs},
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_cookie(self, response: Response, value: str) -> None:
|
||||||
|
"""Attach the cookie to *response*, signing it when ``secret_key`` is set."""
|
||||||
|
cookie_value = self._sign(value) if self._secret_key else value
|
||||||
|
response.set_cookie(
|
||||||
|
self._name,
|
||||||
|
cookie_value,
|
||||||
|
httponly=True,
|
||||||
|
samesite="lax",
|
||||||
|
max_age=self._ttl,
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete_cookie(self, response: Response) -> None:
|
||||||
|
"""Clear the session cookie (logout)."""
|
||||||
|
response.delete_cookie(self._name, httponly=True, samesite="lax")
|
||||||
71
src/fastapi_toolsets/security/sources/header.py
Normal file
71
src/fastapi_toolsets/security/sources/header.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
"""API key header authentication source."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from typing import Annotated, Any, Callable
|
||||||
|
|
||||||
|
from fastapi import Depends, Request
|
||||||
|
from fastapi.security import APIKeyHeader, SecurityScopes
|
||||||
|
|
||||||
|
from fastapi_toolsets.exceptions import UnauthorizedError
|
||||||
|
|
||||||
|
from ..abc import AuthSource, _call_validator
|
||||||
|
|
||||||
|
|
||||||
|
class APIKeyHeaderAuth(AuthSource):
|
||||||
|
"""API key header authentication source.
|
||||||
|
|
||||||
|
Wraps :class:`fastapi.security.APIKeyHeader` for OpenAPI documentation.
|
||||||
|
The validator is called as ``await validator(api_key, **kwargs)``
|
||||||
|
where ``kwargs`` are the extra keyword arguments provided at instantiation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: HTTP header name that carries the API key (e.g. ``"X-API-Key"``).
|
||||||
|
validator: Sync or async callable that receives the API key and any
|
||||||
|
extra keyword arguments, and returns the authenticated identity.
|
||||||
|
Should raise :class:`~fastapi_toolsets.exceptions.UnauthorizedError`
|
||||||
|
on failure.
|
||||||
|
**kwargs: Extra keyword arguments forwarded to the validator on every
|
||||||
|
call (e.g. ``role=Role.ADMIN``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
validator: Callable[..., Any],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
self._name = name
|
||||||
|
self._validator = validator
|
||||||
|
self._kwargs = kwargs
|
||||||
|
self._scheme = APIKeyHeader(name=name, auto_error=False)
|
||||||
|
|
||||||
|
_scheme = self._scheme
|
||||||
|
_validator = validator
|
||||||
|
_kwargs = kwargs
|
||||||
|
|
||||||
|
async def _call(
|
||||||
|
security_scopes: SecurityScopes, # noqa: ARG001
|
||||||
|
api_key: Annotated[str | None, Depends(_scheme)] = None,
|
||||||
|
) -> Any:
|
||||||
|
if api_key is None:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
return await _call_validator(_validator, api_key, **_kwargs)
|
||||||
|
|
||||||
|
self._call_fn = _call
|
||||||
|
self.__signature__ = inspect.signature(_call)
|
||||||
|
|
||||||
|
async def extract(self, request: Request) -> str | None:
|
||||||
|
"""Extract the API key from the configured header."""
|
||||||
|
return request.headers.get(self._name) or None
|
||||||
|
|
||||||
|
async def authenticate(self, credential: str) -> Any:
|
||||||
|
"""Validate a credential and return the identity."""
|
||||||
|
return await _call_validator(self._validator, credential, **self._kwargs)
|
||||||
|
|
||||||
|
def require(self, **kwargs: Any) -> "APIKeyHeaderAuth":
|
||||||
|
"""Return a new instance with additional (or overriding) validator kwargs."""
|
||||||
|
return APIKeyHeaderAuth(
|
||||||
|
self._name,
|
||||||
|
self._validator,
|
||||||
|
**{**self._kwargs, **kwargs},
|
||||||
|
)
|
||||||
121
src/fastapi_toolsets/security/sources/multi.py
Normal file
121
src/fastapi_toolsets/security/sources/multi.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
"""MultiAuth: combine multiple authentication sources into a single callable."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from fastapi.security import SecurityScopes
|
||||||
|
|
||||||
|
from fastapi_toolsets.exceptions import UnauthorizedError
|
||||||
|
|
||||||
|
from ..abc import AuthSource
|
||||||
|
|
||||||
|
|
||||||
|
class MultiAuth:
|
||||||
|
"""Combine multiple authentication sources into a single callable.
|
||||||
|
|
||||||
|
Sources are tried in order; the first one whose
|
||||||
|
:meth:`~AuthSource.extract` returns a non-``None`` credential wins.
|
||||||
|
Its :meth:`~AuthSource.authenticate` is called and the result returned.
|
||||||
|
|
||||||
|
If a credential is found but the validator raises, the exception propagates
|
||||||
|
immediately — the remaining sources are **not** tried. This prevents
|
||||||
|
silent fallthrough on invalid credentials.
|
||||||
|
|
||||||
|
If no source provides a credential,
|
||||||
|
:class:`~fastapi_toolsets.exceptions.UnauthorizedError` is raised.
|
||||||
|
|
||||||
|
The :meth:`~AuthSource.extract` method of each source performs only
|
||||||
|
string matching (no I/O), so prefix-based dispatch is essentially free.
|
||||||
|
|
||||||
|
Any :class:`~AuthSource` subclass — including user-defined ones — can be
|
||||||
|
passed as a source.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*sources: Auth source instances to try in order.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
user_bearer = BearerTokenAuth(verify_user, prefix="user_")
|
||||||
|
org_bearer = BearerTokenAuth(verify_org, prefix="org_")
|
||||||
|
cookie = CookieAuth("session", verify_session)
|
||||||
|
|
||||||
|
multi = MultiAuth(user_bearer, org_bearer, cookie)
|
||||||
|
|
||||||
|
@app.get("/data")
|
||||||
|
async def data_route(user = Security(multi)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
# Apply a shared requirement to all sources at once
|
||||||
|
@app.get("/admin")
|
||||||
|
async def admin_route(user = Security(multi.require(role=Role.ADMIN))):
|
||||||
|
return user
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *sources: AuthSource) -> None:
|
||||||
|
self._sources = sources
|
||||||
|
|
||||||
|
_sources = sources
|
||||||
|
|
||||||
|
async def _call(
|
||||||
|
request: Request,
|
||||||
|
security_scopes: SecurityScopes, # noqa: ARG001
|
||||||
|
**kwargs: Any, # noqa: ARG001 — absorbs scheme values injected by FastAPI
|
||||||
|
) -> Any:
|
||||||
|
for source in _sources:
|
||||||
|
credential = await source.extract(request)
|
||||||
|
if credential is not None:
|
||||||
|
return await source.authenticate(credential)
|
||||||
|
raise UnauthorizedError()
|
||||||
|
|
||||||
|
self._call_fn = _call
|
||||||
|
|
||||||
|
# Build a merged signature that includes the security-scheme Depends()
|
||||||
|
# parameters from every source so FastAPI registers them in OpenAPI docs.
|
||||||
|
seen: set[str] = {"request", "security_scopes"}
|
||||||
|
merged: list[inspect.Parameter] = [
|
||||||
|
inspect.Parameter(
|
||||||
|
"request",
|
||||||
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||||
|
annotation=Request,
|
||||||
|
),
|
||||||
|
inspect.Parameter(
|
||||||
|
"security_scopes",
|
||||||
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||||
|
annotation=SecurityScopes,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for i, source in enumerate(sources):
|
||||||
|
for name, param in inspect.signature(source).parameters.items():
|
||||||
|
if name in seen:
|
||||||
|
continue
|
||||||
|
merged.append(param.replace(name=f"_s{i}_{name}"))
|
||||||
|
seen.add(name)
|
||||||
|
self.__signature__ = inspect.Signature(merged, return_annotation=Any)
|
||||||
|
|
||||||
|
async def __call__(self, **kwargs: Any) -> Any:
|
||||||
|
return await self._call_fn(**kwargs)
|
||||||
|
|
||||||
|
def require(self, **kwargs: Any) -> "MultiAuth":
|
||||||
|
"""Return a new :class:`MultiAuth` with kwargs forwarded to each source.
|
||||||
|
|
||||||
|
Calls ``.require(**kwargs)`` on every source that supports it. Sources
|
||||||
|
that do not implement ``.require()`` (e.g. custom :class:`~AuthSource`
|
||||||
|
subclasses) are passed through unchanged.
|
||||||
|
|
||||||
|
New kwargs are merged over each source's existing kwargs — new values
|
||||||
|
win on conflict::
|
||||||
|
|
||||||
|
multi = MultiAuth(bearer, cookie)
|
||||||
|
|
||||||
|
@app.get("/admin")
|
||||||
|
async def admin(user = Security(multi.require(role=Role.ADMIN))):
|
||||||
|
return user
|
||||||
|
"""
|
||||||
|
new_sources = tuple(
|
||||||
|
cast(Any, source).require(**kwargs)
|
||||||
|
if hasattr(source, "require")
|
||||||
|
else source
|
||||||
|
for source in self._sources
|
||||||
|
)
|
||||||
|
return MultiAuth(*new_sources)
|
||||||
964
tests/test_security.py
Normal file
964
tests/test_security.py
Normal file
@@ -0,0 +1,964 @@
|
|||||||
|
"""Tests for fastapi_toolsets.security."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI, Security
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from fastapi_toolsets.exceptions import UnauthorizedError, init_exceptions_handlers
|
||||||
|
from fastapi_toolsets.security import (
|
||||||
|
APIKeyHeaderAuth,
|
||||||
|
AuthSource,
|
||||||
|
BearerTokenAuth,
|
||||||
|
CookieAuth,
|
||||||
|
MultiAuth,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _app(*routes_setup_fns):
|
||||||
|
"""Build a minimal FastAPI test app with exception handlers."""
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
for fn in routes_setup_fns:
|
||||||
|
fn(app)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
VALID_TOKEN = "secret"
|
||||||
|
VALID_COOKIE = "session123"
|
||||||
|
|
||||||
|
|
||||||
|
async def simple_validator(credential: str) -> dict:
|
||||||
|
if credential != VALID_TOKEN:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
return {"user": "alice"}
|
||||||
|
|
||||||
|
|
||||||
|
async def role_validator(credential: str, *, role: str) -> dict:
|
||||||
|
if credential != VALID_TOKEN:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
return {"user": "alice", "role": role}
|
||||||
|
|
||||||
|
|
||||||
|
async def cookie_validator(value: str) -> dict:
|
||||||
|
if value != VALID_COOKIE:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
return {"session": value}
|
||||||
|
|
||||||
|
|
||||||
|
class TestBearerTokenAuth:
|
||||||
|
def test_valid_token_returns_identity(self):
|
||||||
|
bearer = BearerTokenAuth(simple_validator)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(bearer)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me", headers={"Authorization": f"Bearer {VALID_TOKEN}"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"user": "alice"}
|
||||||
|
|
||||||
|
def test_missing_header_returns_401(self):
|
||||||
|
bearer = BearerTokenAuth(simple_validator)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(bearer)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me")
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_invalid_token_returns_401(self):
|
||||||
|
bearer = BearerTokenAuth(simple_validator)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(bearer)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me", headers={"Authorization": "Bearer wrong"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_kwargs_forwarded_to_validator(self):
|
||||||
|
bearer = BearerTokenAuth(role_validator, role="admin")
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(bearer)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me", headers={"Authorization": f"Bearer {VALID_TOKEN}"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"user": "alice", "role": "admin"}
|
||||||
|
|
||||||
|
def test_prefix_matching_passes_full_token(self):
|
||||||
|
"""Token with matching prefix: full token (with prefix) is passed to validator."""
|
||||||
|
received: list[str] = []
|
||||||
|
|
||||||
|
async def capturing_validator(credential: str) -> dict:
|
||||||
|
received.append(credential)
|
||||||
|
return {"user": "alice"}
|
||||||
|
|
||||||
|
bearer = BearerTokenAuth(capturing_validator, prefix="user_")
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(bearer)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me", headers={"Authorization": "Bearer user_abc123"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
# Prefix is kept — validator receives the full token as stored in DB
|
||||||
|
assert received == ["user_abc123"]
|
||||||
|
|
||||||
|
def test_prefix_mismatch_returns_401(self):
|
||||||
|
bearer = BearerTokenAuth(simple_validator, prefix="user_")
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(bearer)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me", headers={"Authorization": "Bearer org_abc123"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
# --- extract() ---
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_extract_no_header(self):
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
bearer = BearerTokenAuth(simple_validator)
|
||||||
|
scope = {"type": "http", "method": "GET", "path": "/", "headers": []}
|
||||||
|
request = Request(scope)
|
||||||
|
assert await bearer.extract(request) is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_extract_empty_token(self):
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
bearer = BearerTokenAuth(simple_validator)
|
||||||
|
scope = {
|
||||||
|
"type": "http",
|
||||||
|
"method": "GET",
|
||||||
|
"path": "/",
|
||||||
|
"headers": [(b"authorization", b"Bearer ")],
|
||||||
|
}
|
||||||
|
request = Request(scope)
|
||||||
|
assert await bearer.extract(request) is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_extract_no_prefix(self):
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
bearer = BearerTokenAuth(simple_validator)
|
||||||
|
scope = {
|
||||||
|
"type": "http",
|
||||||
|
"method": "GET",
|
||||||
|
"path": "/",
|
||||||
|
"headers": [(b"authorization", b"Bearer mytoken")],
|
||||||
|
}
|
||||||
|
request = Request(scope)
|
||||||
|
assert await bearer.extract(request) == "mytoken"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_extract_prefix_match(self):
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
bearer = BearerTokenAuth(simple_validator, prefix="user_")
|
||||||
|
scope = {
|
||||||
|
"type": "http",
|
||||||
|
"method": "GET",
|
||||||
|
"path": "/",
|
||||||
|
"headers": [(b"authorization", b"Bearer user_abc")],
|
||||||
|
}
|
||||||
|
request = Request(scope)
|
||||||
|
assert await bearer.extract(request) == "user_abc"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_extract_prefix_no_match(self):
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
bearer = BearerTokenAuth(simple_validator, prefix="user_")
|
||||||
|
scope = {
|
||||||
|
"type": "http",
|
||||||
|
"method": "GET",
|
||||||
|
"path": "/",
|
||||||
|
"headers": [(b"authorization", b"Bearer org_abc")],
|
||||||
|
}
|
||||||
|
request = Request(scope)
|
||||||
|
assert await bearer.extract(request) is None
|
||||||
|
|
||||||
|
# --- generate_token() ---
|
||||||
|
|
||||||
|
def test_generate_token_no_prefix(self):
|
||||||
|
bearer = BearerTokenAuth(simple_validator)
|
||||||
|
token = bearer.generate_token()
|
||||||
|
assert isinstance(token, str)
|
||||||
|
assert len(token) > 0
|
||||||
|
|
||||||
|
def test_generate_token_with_prefix(self):
|
||||||
|
bearer = BearerTokenAuth(simple_validator, prefix="user_")
|
||||||
|
token = bearer.generate_token()
|
||||||
|
assert token.startswith("user_")
|
||||||
|
|
||||||
|
def test_generate_token_uniqueness(self):
|
||||||
|
bearer = BearerTokenAuth(simple_validator)
|
||||||
|
assert bearer.generate_token() != bearer.generate_token()
|
||||||
|
|
||||||
|
def test_generate_token_is_valid_credential(self):
|
||||||
|
"""A generated token (with prefix) is accepted by the same auth source."""
|
||||||
|
stored: list[str] = []
|
||||||
|
|
||||||
|
async def storing_validator(credential: str) -> dict:
|
||||||
|
stored.append(credential)
|
||||||
|
return {"token": credential}
|
||||||
|
|
||||||
|
bearer = BearerTokenAuth(storing_validator, prefix="user_")
|
||||||
|
token = bearer.generate_token()
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(bearer)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me", headers={"Authorization": f"Bearer {token}"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert stored == [token]
|
||||||
|
|
||||||
|
|
||||||
|
class TestCookieAuth:
|
||||||
|
def test_valid_cookie_returns_identity(self):
|
||||||
|
cookie_auth = CookieAuth("session", cookie_validator)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(cookie_auth)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me", cookies={"session": VALID_COOKIE})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"session": VALID_COOKIE}
|
||||||
|
|
||||||
|
def test_missing_cookie_returns_401(self):
|
||||||
|
cookie_auth = CookieAuth("session", cookie_validator)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(cookie_auth)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me")
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_invalid_cookie_returns_401(self):
|
||||||
|
cookie_auth = CookieAuth("session", cookie_validator)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(cookie_auth)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me", cookies={"session": "wrong"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_kwargs_forwarded_to_validator(self):
|
||||||
|
async def session_validator(value: str, *, scope: str) -> dict:
|
||||||
|
if value != VALID_COOKIE:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
return {"session": value, "scope": scope}
|
||||||
|
|
||||||
|
cookie_auth = CookieAuth("session", session_validator, scope="read")
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(cookie_auth)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me", cookies={"session": VALID_COOKIE})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"session": VALID_COOKIE, "scope": "read"}
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_extract_no_cookie(self):
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
auth = CookieAuth("session", cookie_validator)
|
||||||
|
scope = {"type": "http", "method": "GET", "path": "/", "headers": []}
|
||||||
|
request = Request(scope)
|
||||||
|
assert await auth.extract(request) is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_extract_cookie_present(self):
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
auth = CookieAuth("session", cookie_validator)
|
||||||
|
scope = {
|
||||||
|
"type": "http",
|
||||||
|
"method": "GET",
|
||||||
|
"path": "/",
|
||||||
|
"headers": [(b"cookie", b"session=abc")],
|
||||||
|
}
|
||||||
|
request = Request(scope)
|
||||||
|
assert await auth.extract(request) == "abc"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAPIKeyHeaderAuth:
|
||||||
|
def test_valid_key_returns_identity(self):
|
||||||
|
auth = APIKeyHeaderAuth("X-API-Key", simple_validator)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(auth)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me", headers={"X-API-Key": VALID_TOKEN})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"user": "alice"}
|
||||||
|
|
||||||
|
def test_missing_header_returns_401(self):
|
||||||
|
auth = APIKeyHeaderAuth("X-API-Key", simple_validator)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(auth)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me")
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_invalid_key_returns_401(self):
|
||||||
|
auth = APIKeyHeaderAuth("X-API-Key", simple_validator)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(auth)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me", headers={"X-API-Key": "wrong"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_kwargs_forwarded_to_validator(self):
|
||||||
|
auth = APIKeyHeaderAuth("X-API-Key", role_validator, role="admin")
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(auth)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me", headers={"X-API-Key": VALID_TOKEN})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"user": "alice", "role": "admin"}
|
||||||
|
|
||||||
|
def test_require_forwards_kwargs(self):
|
||||||
|
auth = APIKeyHeaderAuth("X-API-Key", role_validator)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/admin")
|
||||||
|
async def admin(user=Security(auth.require(role="admin"))):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/admin", headers={"X-API-Key": VALID_TOKEN})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"user": "alice", "role": "admin"}
|
||||||
|
|
||||||
|
def test_require_preserves_name(self):
|
||||||
|
auth = APIKeyHeaderAuth("X-API-Key", simple_validator)
|
||||||
|
derived = auth.require(role="admin")
|
||||||
|
assert derived._name == "X-API-Key"
|
||||||
|
|
||||||
|
def test_require_does_not_mutate_original(self):
|
||||||
|
auth = APIKeyHeaderAuth("X-API-Key", role_validator, role="user")
|
||||||
|
auth.require(role="admin")
|
||||||
|
assert auth._kwargs == {"role": "user"}
|
||||||
|
|
||||||
|
def test_in_multi_auth(self):
|
||||||
|
"""APIKeyHeaderAuth.authenticate() is exercised inside MultiAuth."""
|
||||||
|
bearer = BearerTokenAuth(simple_validator)
|
||||||
|
api_key = APIKeyHeaderAuth("X-API-Key", simple_validator)
|
||||||
|
multi = MultiAuth(bearer, api_key)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(multi)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
# No bearer → falls through to API key header
|
||||||
|
response = client.get("/me", headers={"X-API-Key": VALID_TOKEN})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"user": "alice"}
|
||||||
|
|
||||||
|
def test_is_auth_source(self):
|
||||||
|
auth = APIKeyHeaderAuth("X-API-Key", simple_validator)
|
||||||
|
assert isinstance(auth, AuthSource)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_extract_no_header(self):
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
auth = APIKeyHeaderAuth("X-API-Key", simple_validator)
|
||||||
|
scope = {"type": "http", "method": "GET", "path": "/", "headers": []}
|
||||||
|
request = Request(scope)
|
||||||
|
assert await auth.extract(request) is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_extract_empty_header(self):
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
auth = APIKeyHeaderAuth("X-API-Key", simple_validator)
|
||||||
|
scope = {
|
||||||
|
"type": "http",
|
||||||
|
"method": "GET",
|
||||||
|
"path": "/",
|
||||||
|
"headers": [(b"x-api-key", b"")],
|
||||||
|
}
|
||||||
|
request = Request(scope)
|
||||||
|
assert await auth.extract(request) is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_extract_key_present(self):
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
auth = APIKeyHeaderAuth("X-API-Key", simple_validator)
|
||||||
|
scope = {
|
||||||
|
"type": "http",
|
||||||
|
"method": "GET",
|
||||||
|
"path": "/",
|
||||||
|
"headers": [(b"x-api-key", b"mykey")],
|
||||||
|
}
|
||||||
|
request = Request(scope)
|
||||||
|
assert await auth.extract(request) == "mykey"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMultiAuth:
|
||||||
|
def test_first_source_matches(self):
|
||||||
|
bearer = BearerTokenAuth(simple_validator)
|
||||||
|
cookie = CookieAuth("session", cookie_validator)
|
||||||
|
multi = MultiAuth(bearer, cookie)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(multi)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me", headers={"Authorization": f"Bearer {VALID_TOKEN}"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"user": "alice"}
|
||||||
|
|
||||||
|
def test_second_source_matches_when_first_absent(self):
|
||||||
|
bearer = BearerTokenAuth(simple_validator)
|
||||||
|
cookie = CookieAuth("session", cookie_validator)
|
||||||
|
multi = MultiAuth(bearer, cookie)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(multi)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
# No Authorization header — falls through to cookie
|
||||||
|
response = client.get("/me", cookies={"session": VALID_COOKIE})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"session": VALID_COOKIE}
|
||||||
|
|
||||||
|
def test_no_source_matches_returns_401(self):
|
||||||
|
bearer = BearerTokenAuth(simple_validator)
|
||||||
|
cookie = CookieAuth("session", cookie_validator)
|
||||||
|
multi = MultiAuth(bearer, cookie)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(multi)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me")
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_invalid_credential_does_not_fallthrough(self):
|
||||||
|
"""If a credential is found but invalid, the next source is NOT tried."""
|
||||||
|
second_called: list[bool] = []
|
||||||
|
|
||||||
|
async def tracking_validator(credential: str) -> dict:
|
||||||
|
second_called.append(True)
|
||||||
|
return {"from": "second"}
|
||||||
|
|
||||||
|
bearer = BearerTokenAuth(simple_validator) # raises on wrong token
|
||||||
|
cookie = CookieAuth("session", tracking_validator)
|
||||||
|
multi = MultiAuth(bearer, cookie)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(multi)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
# Bearer credential present but wrong — should NOT try cookie
|
||||||
|
response = client.get(
|
||||||
|
"/me",
|
||||||
|
headers={"Authorization": "Bearer wrong"},
|
||||||
|
cookies={"session": VALID_COOKIE},
|
||||||
|
)
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert second_called == [] # cookie validator was never called
|
||||||
|
|
||||||
|
def test_prefix_routes_to_correct_source(self):
|
||||||
|
"""Prefix-based dispatch: only the matching source's validator is called."""
|
||||||
|
user_calls: list[str] = []
|
||||||
|
org_calls: list[str] = []
|
||||||
|
|
||||||
|
async def user_validator(credential: str) -> dict:
|
||||||
|
user_calls.append(credential)
|
||||||
|
return {"type": "user", "id": credential}
|
||||||
|
|
||||||
|
async def org_validator(credential: str) -> dict:
|
||||||
|
org_calls.append(credential)
|
||||||
|
return {"type": "org", "id": credential}
|
||||||
|
|
||||||
|
user_bearer = BearerTokenAuth(user_validator, prefix="user_")
|
||||||
|
org_bearer = BearerTokenAuth(org_validator, prefix="org_")
|
||||||
|
multi = MultiAuth(user_bearer, org_bearer)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(multi)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
|
||||||
|
response = client.get("/me", headers={"Authorization": "Bearer user_alice"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"type": "user", "id": "user_alice"}
|
||||||
|
assert user_calls == ["user_alice"]
|
||||||
|
assert org_calls == []
|
||||||
|
|
||||||
|
user_calls.clear()
|
||||||
|
|
||||||
|
response = client.get("/me", headers={"Authorization": "Bearer org_acme"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"type": "org", "id": "org_acme"}
|
||||||
|
assert user_calls == []
|
||||||
|
assert org_calls == ["org_acme"]
|
||||||
|
|
||||||
|
def test_require_returns_new_multi_auth(self):
|
||||||
|
from fastapi_toolsets.security.sources import MultiAuth as MultiAuthClass
|
||||||
|
|
||||||
|
bearer = BearerTokenAuth(role_validator)
|
||||||
|
multi = MultiAuth(bearer)
|
||||||
|
derived = multi.require(role="admin")
|
||||||
|
assert isinstance(derived, MultiAuthClass)
|
||||||
|
assert derived is not multi
|
||||||
|
|
||||||
|
def test_require_forwards_kwargs_to_sources(self):
|
||||||
|
"""multi.require() propagates to all sources that support it."""
|
||||||
|
bearer = BearerTokenAuth(role_validator)
|
||||||
|
multi = MultiAuth(bearer)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/admin")
|
||||||
|
async def admin(user=Security(multi.require(role="admin"))):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get(
|
||||||
|
"/admin", headers={"Authorization": f"Bearer {VALID_TOKEN}"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"user": "alice", "role": "admin"}
|
||||||
|
|
||||||
|
def test_require_skips_sources_without_require(self):
|
||||||
|
"""Sources without require() are passed through unchanged."""
|
||||||
|
header_auth = _HeaderAuth(secret="s3cr3t")
|
||||||
|
multi = MultiAuth(header_auth)
|
||||||
|
derived = multi.require(role="admin")
|
||||||
|
assert derived._sources[0] is header_auth
|
||||||
|
|
||||||
|
def test_require_does_not_mutate_original(self):
|
||||||
|
bearer = BearerTokenAuth(role_validator, role="user")
|
||||||
|
multi = MultiAuth(bearer)
|
||||||
|
multi.require(role="admin")
|
||||||
|
assert bearer._kwargs == {"role": "user"}
|
||||||
|
|
||||||
|
def test_require_mixed_sources(self):
|
||||||
|
"""require() applies to sources with require(), skips those without."""
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
bearer = BearerTokenAuth(role_validator)
|
||||||
|
header_auth = _HeaderAuth(secret="s3cr3t")
|
||||||
|
multi = MultiAuth(bearer, header_auth)
|
||||||
|
derived = multi.require(role="admin")
|
||||||
|
# bearer got require() applied, header_auth passed through
|
||||||
|
assert cast(BearerTokenAuth, derived._sources[0])._kwargs == {"role": "admin"}
|
||||||
|
assert derived._sources[1] is header_auth
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequire:
|
||||||
|
def test_bearer_require_forwards_kwargs(self):
|
||||||
|
"""require() creates a new instance that passes merged kwargs to validator."""
|
||||||
|
bearer = BearerTokenAuth(role_validator)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/admin")
|
||||||
|
async def admin(user=Security(bearer.require(role="admin"))):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get(
|
||||||
|
"/admin", headers={"Authorization": f"Bearer {VALID_TOKEN}"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"user": "alice", "role": "admin"}
|
||||||
|
|
||||||
|
def test_bearer_require_overrides_existing_kwarg(self):
|
||||||
|
"""require() kwargs override kwargs set at instantiation."""
|
||||||
|
bearer = BearerTokenAuth(role_validator, role="user")
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/admin")
|
||||||
|
async def admin(user=Security(bearer.require(role="admin"))):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get(
|
||||||
|
"/admin", headers={"Authorization": f"Bearer {VALID_TOKEN}"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["role"] == "admin"
|
||||||
|
|
||||||
|
def test_bearer_require_preserves_prefix(self):
|
||||||
|
"""require() keeps the prefix of the original instance."""
|
||||||
|
bearer = BearerTokenAuth(role_validator, prefix="user_")
|
||||||
|
derived = bearer.require(role="admin")
|
||||||
|
assert derived._prefix == "user_"
|
||||||
|
|
||||||
|
def test_bearer_require_does_not_mutate_original(self):
|
||||||
|
"""require() returns a new instance — original kwargs are unchanged."""
|
||||||
|
bearer = BearerTokenAuth(role_validator, role="user")
|
||||||
|
bearer.require(role="admin")
|
||||||
|
assert bearer._kwargs == {"role": "user"}
|
||||||
|
|
||||||
|
def test_cookie_require_forwards_kwargs(self):
|
||||||
|
async def scoped_validator(value: str, *, scope: str) -> dict:
|
||||||
|
if value != VALID_COOKIE:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
return {"session": value, "scope": scope}
|
||||||
|
|
||||||
|
cookie = CookieAuth("session", scoped_validator)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/admin")
|
||||||
|
async def admin(user=Security(cookie.require(scope="admin"))):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/admin", cookies={"session": VALID_COOKIE})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"session": VALID_COOKIE, "scope": "admin"}
|
||||||
|
|
||||||
|
def test_cookie_require_preserves_name(self):
|
||||||
|
cookie = CookieAuth("session", cookie_validator)
|
||||||
|
derived = cookie.require(scope="admin")
|
||||||
|
assert derived._name == "session"
|
||||||
|
|
||||||
|
def test_bearer_require_in_multi_auth(self):
|
||||||
|
"""require() instances work seamlessly inside MultiAuth."""
|
||||||
|
PREFIXED_TOKEN = f"user_{VALID_TOKEN}"
|
||||||
|
|
||||||
|
async def prefixed_role_validator(credential: str, *, role: str) -> dict:
|
||||||
|
if credential != PREFIXED_TOKEN:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
return {"user": "alice", "role": role}
|
||||||
|
|
||||||
|
bearer = BearerTokenAuth(prefixed_role_validator, prefix="user_")
|
||||||
|
multi = MultiAuth(bearer.require(role="admin"))
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/admin")
|
||||||
|
async def admin(user=Security(multi)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get(
|
||||||
|
"/admin", headers={"Authorization": f"Bearer {PREFIXED_TOKEN}"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"user": "alice", "role": "admin"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncValidators:
|
||||||
|
"""Sync (non-async) validators — covers the sync path in _call_validator."""
|
||||||
|
|
||||||
|
def test_bearer_sync_validator(self):
|
||||||
|
def sync_validator(credential: str) -> dict:
|
||||||
|
if credential != VALID_TOKEN:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
return {"user": "alice"}
|
||||||
|
|
||||||
|
bearer = BearerTokenAuth(sync_validator)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(bearer)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me", headers={"Authorization": f"Bearer {VALID_TOKEN}"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"user": "alice"}
|
||||||
|
|
||||||
|
def test_sync_validator_via_authenticate(self):
|
||||||
|
"""authenticate() with sync validator (MultiAuth path)."""
|
||||||
|
|
||||||
|
def sync_validator(credential: str) -> dict:
|
||||||
|
if credential != VALID_TOKEN:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
return {"user": "alice"}
|
||||||
|
|
||||||
|
bearer = BearerTokenAuth(sync_validator)
|
||||||
|
multi = MultiAuth(bearer)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(multi)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me", headers={"Authorization": f"Bearer {VALID_TOKEN}"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"user": "alice"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestCookieAuthSigned:
|
||||||
|
"""CookieAuth with HMAC-SHA256 signed cookies (secret_key path)."""
|
||||||
|
|
||||||
|
SECRET = "test-hmac-secret"
|
||||||
|
|
||||||
|
def test_valid_signed_cookie_via_set_cookie(self):
|
||||||
|
"""set_cookie signs the value; the signed cookie is verified on read."""
|
||||||
|
from fastapi import Response
|
||||||
|
|
||||||
|
auth = CookieAuth("session", cookie_validator, secret_key=self.SECRET)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/login")
|
||||||
|
async def login(response: Response):
|
||||||
|
auth.set_cookie(response, VALID_COOKIE)
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(auth)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
with TestClient(_app(setup)) as client:
|
||||||
|
client.get("/login")
|
||||||
|
response = client.get("/me")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"session": VALID_COOKIE}
|
||||||
|
|
||||||
|
def test_tampered_signature_returns_401(self):
|
||||||
|
"""A cookie whose HMAC signature has been modified is rejected."""
|
||||||
|
import base64 as _b64
|
||||||
|
import json as _json
|
||||||
|
import time as _time
|
||||||
|
|
||||||
|
auth = CookieAuth("session", cookie_validator, secret_key=self.SECRET)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(auth)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
data = _b64.urlsafe_b64encode(
|
||||||
|
_json.dumps({"v": VALID_COOKIE, "exp": int(_time.time()) + 9999}).encode()
|
||||||
|
).decode()
|
||||||
|
response = client.get("/me", cookies={"session": f"{data}.invalidsig"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_expired_signed_cookie_returns_401(self):
|
||||||
|
"""A signed cookie past its expiry timestamp is rejected."""
|
||||||
|
import base64 as _b64
|
||||||
|
import hashlib as _hashlib
|
||||||
|
import hmac as _hmac
|
||||||
|
import json as _json
|
||||||
|
import time as _time
|
||||||
|
|
||||||
|
auth = CookieAuth("session", cookie_validator, secret_key=self.SECRET)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(auth)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
data = _b64.urlsafe_b64encode(
|
||||||
|
_json.dumps({"v": VALID_COOKIE, "exp": int(_time.time()) - 1}).encode()
|
||||||
|
).decode()
|
||||||
|
sig = _hmac.new(
|
||||||
|
self.SECRET.encode(), data.encode(), _hashlib.sha256
|
||||||
|
).hexdigest()
|
||||||
|
response = client.get("/me", cookies={"session": f"{data}.{sig}"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_invalid_json_payload_returns_401(self):
|
||||||
|
"""A signed cookie whose payload is not valid JSON is rejected."""
|
||||||
|
import base64 as _b64
|
||||||
|
import hashlib as _hashlib
|
||||||
|
import hmac as _hmac
|
||||||
|
|
||||||
|
auth = CookieAuth("session", cookie_validator, secret_key=self.SECRET)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(auth)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
data = _b64.urlsafe_b64encode(b"not-valid-json").decode()
|
||||||
|
sig = _hmac.new(
|
||||||
|
self.SECRET.encode(), data.encode(), _hashlib.sha256
|
||||||
|
).hexdigest()
|
||||||
|
response = client.get("/me", cookies={"session": f"{data}.{sig}"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_malformed_cookie_no_dot_returns_401(self):
|
||||||
|
"""A signed cookie without the dot separator is rejected."""
|
||||||
|
auth = CookieAuth("session", cookie_validator, secret_key=self.SECRET)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(auth)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me", cookies={"session": "nodothere"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_set_cookie_without_secret(self):
|
||||||
|
"""set_cookie without secret_key writes the raw value."""
|
||||||
|
from starlette.responses import Response as StarletteResponse
|
||||||
|
|
||||||
|
auth = CookieAuth("session", cookie_validator)
|
||||||
|
response = StarletteResponse()
|
||||||
|
auth.set_cookie(response, "rawvalue")
|
||||||
|
assert "session=rawvalue" in response.headers["set-cookie"]
|
||||||
|
|
||||||
|
def test_delete_cookie(self):
|
||||||
|
"""delete_cookie produces a Set-Cookie header that clears the session."""
|
||||||
|
from starlette.responses import Response as StarletteResponse
|
||||||
|
|
||||||
|
auth = CookieAuth("session", cookie_validator)
|
||||||
|
response = StarletteResponse()
|
||||||
|
auth.delete_cookie(response)
|
||||||
|
assert "session" in response.headers["set-cookie"]
|
||||||
|
|
||||||
|
|
||||||
|
# Minimal concrete subclass used only in tests below.
|
||||||
|
class _HeaderAuth(AuthSource):
|
||||||
|
"""Reads a custom X-Token header — no FastAPI security scheme."""
|
||||||
|
|
||||||
|
def __init__(self, secret: str) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._secret = secret
|
||||||
|
|
||||||
|
async def extract(self, request) -> str | None:
|
||||||
|
return request.headers.get("X-Token") or None
|
||||||
|
|
||||||
|
async def authenticate(self, credential: str) -> dict:
|
||||||
|
if credential != self._secret:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
return {"token": credential}
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthSource:
|
||||||
|
def test_cannot_instantiate_abstract_class(self):
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
AuthSource()
|
||||||
|
|
||||||
|
def test_builtin_classes_are_auth_sources(self):
|
||||||
|
bearer = BearerTokenAuth(simple_validator)
|
||||||
|
cookie = CookieAuth("session", cookie_validator)
|
||||||
|
api_key = APIKeyHeaderAuth("X-API-Key", simple_validator)
|
||||||
|
assert isinstance(bearer, AuthSource)
|
||||||
|
assert isinstance(cookie, AuthSource)
|
||||||
|
assert isinstance(api_key, AuthSource)
|
||||||
|
|
||||||
|
def test_custom_source_standalone_valid(self):
|
||||||
|
"""Default __call__ wires extract + authenticate via Request injection."""
|
||||||
|
auth = _HeaderAuth(secret="s3cr3t")
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(auth)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me", headers={"X-Token": "s3cr3t"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"token": "s3cr3t"}
|
||||||
|
|
||||||
|
def test_custom_source_standalone_missing_credential(self):
|
||||||
|
auth = _HeaderAuth(secret="s3cr3t")
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(auth)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me") # no X-Token header
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_custom_source_standalone_invalid_credential(self):
|
||||||
|
auth = _HeaderAuth(secret="s3cr3t")
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(auth)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
response = client.get("/me", headers={"X-Token": "wrong"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_custom_source_in_multi_auth(self):
|
||||||
|
"""Custom AuthSource works transparently inside MultiAuth."""
|
||||||
|
header_auth = _HeaderAuth(secret="s3cr3t")
|
||||||
|
bearer = BearerTokenAuth(simple_validator)
|
||||||
|
multi = MultiAuth(bearer, header_auth)
|
||||||
|
|
||||||
|
def setup(app: FastAPI):
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user=Security(multi)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
client = TestClient(_app(setup))
|
||||||
|
|
||||||
|
# Bearer matches first
|
||||||
|
response = client.get("/me", headers={"Authorization": f"Bearer {VALID_TOKEN}"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"user": "alice"}
|
||||||
|
|
||||||
|
# No bearer → falls through to custom header source
|
||||||
|
response = client.get("/me", headers={"X-Token": "s3cr3t"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"token": "s3cr3t"}
|
||||||
Reference in New Issue
Block a user