feat: add security module

This commit is contained in:
2026-03-04 11:02:54 -05:00
parent 9b74f162ab
commit c785b4c1ef
11 changed files with 1813 additions and 0 deletions

View File

@@ -0,0 +1,15 @@
"""Authentication helpers for FastAPI using Security()."""
from .abc import AuthSource
from .oauth import decode_oauth_state, encode_oauth_state
from .sources import APIKeyHeaderAuth, BearerTokenAuth, CookieAuth, MultiAuth
__all__ = [
"APIKeyHeaderAuth",
"AuthSource",
"BearerTokenAuth",
"CookieAuth",
"MultiAuth",
"decode_oauth_state",
"encode_oauth_state",
]

View File

@@ -0,0 +1,51 @@
"""Abstract base class for authentication sources."""
import inspect
from abc import ABC, abstractmethod
from typing import Any, Callable
from fastapi import Request
from fastapi.security import SecurityScopes
from fastapi_toolsets.exceptions import UnauthorizedError
async def _call_validator(
validator: Callable[..., Any], *args: Any, **kwargs: Any
) -> Any:
"""Call *validator* with *args* and *kwargs*, awaiting it if it is a coroutine function."""
if inspect.iscoroutinefunction(validator):
return await validator(*args, **kwargs)
return validator(*args, **kwargs)
class AuthSource(ABC):
"""Abstract base class for authentication sources."""
def __init__(self) -> None:
"""Set up the default FastAPI dependency signature."""
source = self
async def _call(
request: Request,
security_scopes: SecurityScopes, # noqa: ARG001
) -> Any:
credential = await source.extract(request)
if credential is None:
raise UnauthorizedError()
return await source.authenticate(credential)
self._call_fn: Callable[..., Any] = _call
self.__signature__ = inspect.signature(_call)
@abstractmethod
async def extract(self, request: Request) -> str | None:
"""Extract the raw credential from the request without validating."""
@abstractmethod
async def authenticate(self, credential: str) -> Any:
"""Validate a credential and return the authenticated identity."""
async def __call__(self, **kwargs: Any) -> Any:
"""FastAPI dependency dispatch."""
return await self._call_fn(**kwargs)

View File

@@ -0,0 +1,24 @@
"""OAuth 2.0 / OIDC helper utilities."""
import base64
def encode_oauth_state(url: str) -> str:
"""Base64url-encode a URL to embed as an OAuth ``state`` parameter."""
return base64.urlsafe_b64encode(url.encode()).decode()
def decode_oauth_state(state: str | None, *, fallback: str) -> str:
"""Decode a base64url OAuth ``state`` parameter.
Handles missing padding (some providers strip ``=``).
Returns *fallback* if *state* is absent, the literal string ``"null"``,
or cannot be decoded.
"""
if not state or state == "null":
return fallback
try:
padded = state + "=" * (4 - len(state) % 4)
return base64.urlsafe_b64decode(padded).decode()
except Exception:
return fallback

View File

@@ -0,0 +1,8 @@
"""Built-in authentication source implementations."""
from .header import APIKeyHeaderAuth
from .bearer import BearerTokenAuth
from .cookie import CookieAuth
from .multi import MultiAuth
__all__ = ["APIKeyHeaderAuth", "BearerTokenAuth", "CookieAuth", "MultiAuth"]

View File

@@ -0,0 +1,122 @@
"""Bearer token authentication source."""
import inspect
import secrets
from typing import Annotated, Any, Callable
from fastapi import Depends
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer, SecurityScopes
from fastapi_toolsets.exceptions import UnauthorizedError
from ..abc import AuthSource, _call_validator
class BearerTokenAuth(AuthSource):
"""Bearer token authentication source.
Wraps :class:`fastapi.security.HTTPBearer` for OpenAPI documentation.
The validator is called as ``await validator(credential, **kwargs)``
where ``kwargs`` are the extra keyword arguments provided at instantiation.
Args:
validator: Sync or async callable that receives the credential and any
extra keyword arguments, and returns the authenticated identity
(e.g. a ``User`` model). Should raise
:class:`~fastapi_toolsets.exceptions.UnauthorizedError` on failure.
prefix: Optional token prefix (e.g. ``"user_"``). If set, only tokens
whose value starts with this prefix are matched. The prefix is
**kept** in the value passed to the validator — store and compare
tokens with their prefix included. Use :meth:`generate_token` to
create correctly-prefixed tokens. This enables multiple
``BearerTokenAuth`` instances in the same app (e.g. ``"user_"``
for user tokens, ``"org_"`` for org tokens).
**kwargs: Extra keyword arguments forwarded to the validator on every
call (e.g. ``role=Role.ADMIN``).
"""
def __init__(
self,
validator: Callable[..., Any],
*,
prefix: str | None = None,
**kwargs: Any,
) -> None:
self._validator = validator
self._prefix = prefix
self._kwargs = kwargs
self._scheme = HTTPBearer(auto_error=False)
_scheme = self._scheme
_validator = validator
_kwargs = kwargs
_prefix = prefix
async def _call(
security_scopes: SecurityScopes, # noqa: ARG001
credentials: Annotated[
HTTPAuthorizationCredentials | None, Depends(_scheme)
] = None,
) -> Any:
if credentials is None:
raise UnauthorizedError()
token = credentials.credentials
if _prefix is not None and not token.startswith(_prefix):
raise UnauthorizedError()
return await _call_validator(_validator, token, **_kwargs)
self._call_fn = _call
self.__signature__ = inspect.signature(_call)
async def extract(self, request: Any) -> str | None:
"""Extract the raw credential from the request without validating.
Returns ``None`` if no ``Authorization: Bearer`` header is present,
the token is empty, or the token does not match the configured prefix.
The prefix is included in the returned value.
"""
auth = request.headers.get("Authorization", "")
if not auth.startswith("Bearer "):
return None
token = auth[7:]
if not token:
return None
if self._prefix is not None and not token.startswith(self._prefix):
return None
return token
async def authenticate(self, credential: str) -> Any:
"""Validate a credential and return the identity.
Calls ``await validator(credential, **kwargs)`` where ``kwargs`` are
the extra keyword arguments provided at instantiation.
"""
return await _call_validator(self._validator, credential, **self._kwargs)
def require(self, **kwargs: Any) -> "BearerTokenAuth":
"""Return a new instance with additional (or overriding) validator kwargs."""
return BearerTokenAuth(
self._validator,
prefix=self._prefix,
**{**self._kwargs, **kwargs},
)
def generate_token(self, nbytes: int = 32) -> str:
"""Generate a secure random token for this auth source.
Returns a URL-safe random token. If a prefix is configured it is
prepended — the returned value is what you store in your database
and return to the client as-is.
Args:
nbytes: Number of random bytes before base64 encoding. The
resulting string is ``ceil(nbytes * 4 / 3)`` characters
(43 chars for the default 32 bytes). Defaults to 32.
Returns:
A ready-to-use token string (e.g. ``"user_Xk3..."``).
"""
token = secrets.token_urlsafe(nbytes)
if self._prefix is not None:
return f"{self._prefix}{token}"
return token

View File

@@ -0,0 +1,142 @@
"""Cookie-based authentication source."""
import base64
import hashlib
import hmac
import inspect
import json
import time
from typing import Annotated, Any, Callable
from fastapi import Depends, Request, Response
from fastapi.security import APIKeyCookie, SecurityScopes
from fastapi_toolsets.exceptions import UnauthorizedError
from ..abc import AuthSource, _call_validator
class CookieAuth(AuthSource):
"""Cookie-based authentication source.
Wraps :class:`fastapi.security.APIKeyCookie` for OpenAPI documentation.
Optionally signs the cookie with HMAC-SHA256 to provide stateless, tamper-
proof sessions without any database entry.
Args:
name: Cookie name.
validator: Sync or async callable that receives the cookie value
(plain, after signature verification when ``secret_key`` is set)
and any extra keyword arguments, and returns the authenticated
identity.
secret_key: When provided, the cookie is HMAC-SHA256 signed.
:meth:`set_cookie` embeds an expiry and signs the payload;
:meth:`extract` verifies the signature and expiry before handing
the plain value to the validator. When ``None`` (default), the raw
cookie value is passed to the validator as-is.
ttl: Cookie lifetime in seconds (default 24 h). Only used when
``secret_key`` is set.
**kwargs: Extra keyword arguments forwarded to the validator on every
call (e.g. ``role=Role.ADMIN``).
"""
def __init__(
self,
name: str,
validator: Callable[..., Any],
*,
secret_key: str | None = None,
ttl: int = 86400,
**kwargs: Any,
) -> None:
self._name = name
self._validator = validator
self._secret_key = secret_key
self._ttl = ttl
self._kwargs = kwargs
self._scheme = APIKeyCookie(name=name, auto_error=False)
_scheme = self._scheme
_self = self
_kwargs = kwargs
async def _call(
security_scopes: SecurityScopes, # noqa: ARG001
value: Annotated[str | None, Depends(_scheme)] = None,
) -> Any:
if value is None:
raise UnauthorizedError()
plain = _self._verify(value)
return await _call_validator(_self._validator, plain, **_kwargs)
self._call_fn = _call
self.__signature__ = inspect.signature(_call)
def _hmac(self, data: str) -> str:
assert self._secret_key is not None
return hmac.new(
self._secret_key.encode(), data.encode(), hashlib.sha256
).hexdigest()
def _sign(self, value: str) -> str:
data = base64.urlsafe_b64encode(
json.dumps({"v": value, "exp": int(time.time()) + self._ttl}).encode()
).decode()
return f"{data}.{self._hmac(data)}"
def _verify(self, cookie_value: str) -> str:
"""Return the plain value, verifying HMAC + expiry when signed."""
if not self._secret_key:
return cookie_value
try:
data, sig = cookie_value.rsplit(".", 1)
except ValueError:
raise UnauthorizedError()
if not hmac.compare_digest(self._hmac(data), sig):
raise UnauthorizedError()
try:
payload = json.loads(base64.urlsafe_b64decode(data))
value: str = payload["v"]
exp: int = payload["exp"]
except Exception:
raise UnauthorizedError()
if exp < int(time.time()):
raise UnauthorizedError()
return value
async def extract(self, request: Request) -> str | None:
return request.cookies.get(self._name)
async def authenticate(self, credential: str) -> Any:
plain = self._verify(credential)
return await _call_validator(self._validator, plain, **self._kwargs)
def require(self, **kwargs: Any) -> "CookieAuth":
"""Return a new instance with additional (or overriding) validator kwargs."""
return CookieAuth(
self._name,
self._validator,
secret_key=self._secret_key,
ttl=self._ttl,
**{**self._kwargs, **kwargs},
)
def set_cookie(self, response: Response, value: str) -> None:
"""Attach the cookie to *response*, signing it when ``secret_key`` is set."""
cookie_value = self._sign(value) if self._secret_key else value
response.set_cookie(
self._name,
cookie_value,
httponly=True,
samesite="lax",
max_age=self._ttl,
)
def delete_cookie(self, response: Response) -> None:
"""Clear the session cookie (logout)."""
response.delete_cookie(self._name, httponly=True, samesite="lax")

View File

@@ -0,0 +1,71 @@
"""API key header authentication source."""
import inspect
from typing import Annotated, Any, Callable
from fastapi import Depends, Request
from fastapi.security import APIKeyHeader, SecurityScopes
from fastapi_toolsets.exceptions import UnauthorizedError
from ..abc import AuthSource, _call_validator
class APIKeyHeaderAuth(AuthSource):
"""API key header authentication source.
Wraps :class:`fastapi.security.APIKeyHeader` for OpenAPI documentation.
The validator is called as ``await validator(api_key, **kwargs)``
where ``kwargs`` are the extra keyword arguments provided at instantiation.
Args:
name: HTTP header name that carries the API key (e.g. ``"X-API-Key"``).
validator: Sync or async callable that receives the API key and any
extra keyword arguments, and returns the authenticated identity.
Should raise :class:`~fastapi_toolsets.exceptions.UnauthorizedError`
on failure.
**kwargs: Extra keyword arguments forwarded to the validator on every
call (e.g. ``role=Role.ADMIN``).
"""
def __init__(
self,
name: str,
validator: Callable[..., Any],
**kwargs: Any,
) -> None:
self._name = name
self._validator = validator
self._kwargs = kwargs
self._scheme = APIKeyHeader(name=name, auto_error=False)
_scheme = self._scheme
_validator = validator
_kwargs = kwargs
async def _call(
security_scopes: SecurityScopes, # noqa: ARG001
api_key: Annotated[str | None, Depends(_scheme)] = None,
) -> Any:
if api_key is None:
raise UnauthorizedError()
return await _call_validator(_validator, api_key, **_kwargs)
self._call_fn = _call
self.__signature__ = inspect.signature(_call)
async def extract(self, request: Request) -> str | None:
"""Extract the API key from the configured header."""
return request.headers.get(self._name) or None
async def authenticate(self, credential: str) -> Any:
"""Validate a credential and return the identity."""
return await _call_validator(self._validator, credential, **self._kwargs)
def require(self, **kwargs: Any) -> "APIKeyHeaderAuth":
"""Return a new instance with additional (or overriding) validator kwargs."""
return APIKeyHeaderAuth(
self._name,
self._validator,
**{**self._kwargs, **kwargs},
)

View File

@@ -0,0 +1,121 @@
"""MultiAuth: combine multiple authentication sources into a single callable."""
import inspect
from typing import Any, cast
from fastapi import Request
from fastapi.security import SecurityScopes
from fastapi_toolsets.exceptions import UnauthorizedError
from ..abc import AuthSource
class MultiAuth:
"""Combine multiple authentication sources into a single callable.
Sources are tried in order; the first one whose
:meth:`~AuthSource.extract` returns a non-``None`` credential wins.
Its :meth:`~AuthSource.authenticate` is called and the result returned.
If a credential is found but the validator raises, the exception propagates
immediately — the remaining sources are **not** tried. This prevents
silent fallthrough on invalid credentials.
If no source provides a credential,
:class:`~fastapi_toolsets.exceptions.UnauthorizedError` is raised.
The :meth:`~AuthSource.extract` method of each source performs only
string matching (no I/O), so prefix-based dispatch is essentially free.
Any :class:`~AuthSource` subclass — including user-defined ones — can be
passed as a source.
Args:
*sources: Auth source instances to try in order.
Example::
user_bearer = BearerTokenAuth(verify_user, prefix="user_")
org_bearer = BearerTokenAuth(verify_org, prefix="org_")
cookie = CookieAuth("session", verify_session)
multi = MultiAuth(user_bearer, org_bearer, cookie)
@app.get("/data")
async def data_route(user = Security(multi)):
return user
# Apply a shared requirement to all sources at once
@app.get("/admin")
async def admin_route(user = Security(multi.require(role=Role.ADMIN))):
return user
"""
def __init__(self, *sources: AuthSource) -> None:
self._sources = sources
_sources = sources
async def _call(
request: Request,
security_scopes: SecurityScopes, # noqa: ARG001
**kwargs: Any, # noqa: ARG001 — absorbs scheme values injected by FastAPI
) -> Any:
for source in _sources:
credential = await source.extract(request)
if credential is not None:
return await source.authenticate(credential)
raise UnauthorizedError()
self._call_fn = _call
# Build a merged signature that includes the security-scheme Depends()
# parameters from every source so FastAPI registers them in OpenAPI docs.
seen: set[str] = {"request", "security_scopes"}
merged: list[inspect.Parameter] = [
inspect.Parameter(
"request",
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=Request,
),
inspect.Parameter(
"security_scopes",
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=SecurityScopes,
),
]
for i, source in enumerate(sources):
for name, param in inspect.signature(source).parameters.items():
if name in seen:
continue
merged.append(param.replace(name=f"_s{i}_{name}"))
seen.add(name)
self.__signature__ = inspect.Signature(merged, return_annotation=Any)
async def __call__(self, **kwargs: Any) -> Any:
return await self._call_fn(**kwargs)
def require(self, **kwargs: Any) -> "MultiAuth":
"""Return a new :class:`MultiAuth` with kwargs forwarded to each source.
Calls ``.require(**kwargs)`` on every source that supports it. Sources
that do not implement ``.require()`` (e.g. custom :class:`~AuthSource`
subclasses) are passed through unchanged.
New kwargs are merged over each source's existing kwargs — new values
win on conflict::
multi = MultiAuth(bearer, cookie)
@app.get("/admin")
async def admin(user = Security(multi.require(role=Role.ADMIN))):
return user
"""
new_sources = tuple(
cast(Any, source).require(**kwargs)
if hasattr(source, "require")
else source
for source in self._sources
)
return MultiAuth(*new_sources)