From 59231bd5d026f1a8805d87ed0e71c191f2d03ae5 Mon Sep 17 00:00:00 2001 From: d3vyce Date: Wed, 4 Mar 2026 11:02:54 -0500 Subject: [PATCH] feat: add security module --- docs/module/security.md | 267 +++++ docs/reference/security.md | 28 + src/fastapi_toolsets/security/__init__.py | 15 + src/fastapi_toolsets/security/abc.py | 51 + src/fastapi_toolsets/security/oauth.py | 24 + .../security/sources/__init__.py | 8 + .../security/sources/bearer.py | 122 +++ .../security/sources/cookie.py | 142 +++ .../security/sources/header.py | 71 ++ .../security/sources/multi.py | 121 +++ tests/test_security.py | 964 ++++++++++++++++++ 11 files changed, 1813 insertions(+) create mode 100644 docs/module/security.md create mode 100644 docs/reference/security.md create mode 100644 src/fastapi_toolsets/security/__init__.py create mode 100644 src/fastapi_toolsets/security/abc.py create mode 100644 src/fastapi_toolsets/security/oauth.py create mode 100644 src/fastapi_toolsets/security/sources/__init__.py create mode 100644 src/fastapi_toolsets/security/sources/bearer.py create mode 100644 src/fastapi_toolsets/security/sources/cookie.py create mode 100644 src/fastapi_toolsets/security/sources/header.py create mode 100644 src/fastapi_toolsets/security/sources/multi.py create mode 100644 tests/test_security.py diff --git a/docs/module/security.md b/docs/module/security.md new file mode 100644 index 0000000..6ddddb5 --- /dev/null +++ b/docs/module/security.md @@ -0,0 +1,267 @@ +# Security + +Composable authentication helpers for FastAPI that use `Security()` for OpenAPI documentation and accept user-provided validator functions with full type flexibility. + +## Overview + +The `security` module provides four auth source classes and a `MultiAuth` factory. Each class wraps a FastAPI security scheme for OpenAPI and accepts a validator function called as: + +```python +await validator(credential, **kwargs) +``` + +where `kwargs` are the extra keyword arguments provided at instantiation (roles, permissions, enums, etc.). The validator returns the authenticated identity (e.g. a `User` model) which becomes the route dependency value. + +```python +from fastapi import Security +from fastapi_toolsets.security import BearerTokenAuth + +async def verify_token(token: str, *, role: str) -> User: + user = await db.get_by_token(token) + if not user or user.role != role: + raise UnauthorizedError() + return user + +bearer_admin = BearerTokenAuth(verify_token, role="admin") + +@app.get("/admin") +async def admin_route(user: User = Security(bearer_admin)): + return user +``` + +## Auth sources + +### [`BearerTokenAuth`](../reference/security.md#fastapi_toolsets.security.BearerTokenAuth) + +Reads the `Authorization: Bearer ` header. Wraps `HTTPBearer` for OpenAPI. + +```python +from fastapi_toolsets.security import BearerTokenAuth + +bearer = BearerTokenAuth(validator=verify_token) + +@app.get("/me") +async def me(user: User = Security(bearer)): + return user +``` + +#### Token prefix + +The optional `prefix` parameter restricts a `BearerTokenAuth` instance to tokens +that start with a given string. The prefix is **kept** in the value passed to the +validator — store and compare tokens with their prefix included. + +This lets you deploy multiple `BearerTokenAuth` instances in the same application +and disambiguate them efficiently in `MultiAuth`: + +```python +user_bearer = BearerTokenAuth(verify_user, prefix="user_") # matches "Bearer user_..." +org_bearer = BearerTokenAuth(verify_org, prefix="org_") # matches "Bearer org_..." +``` + +Use [`generate_token()`](#token-generation) to create correctly-prefixed tokens. + +#### Token generation + +`BearerTokenAuth.generate_token()` produces a secure random token ready to store +in your database and return to the client. If a prefix is configured it is +prepended automatically: + +```python +bearer = BearerTokenAuth(verify_token, prefix="user_") + +token = bearer.generate_token() # e.g. "user_Xk3mN..." +await db.store_token(user_id, token) +return {"access_token": token, "token_type": "bearer"} +``` + +The client sends `Authorization: Bearer user_Xk3mN...` and the validator receives +the full token (prefix included) to compare against the stored value. + +### [`CookieAuth`](../reference/security.md#fastapi_toolsets.security.CookieAuth) + +Reads a named cookie. Wraps `APIKeyCookie` for OpenAPI. + +```python +from fastapi_toolsets.security import CookieAuth + +cookie_auth = CookieAuth("session", validator=verify_session) + +@app.get("/me") +async def me(user: User = Security(cookie_auth)): + return user +``` + +### [`OAuth2Auth`](../reference/security.md#fastapi_toolsets.security.OAuth2Auth) + +Reads the `Authorization: Bearer ` header and registers the token endpoint +in OpenAPI via `OAuth2PasswordBearer`. + +```python +from fastapi_toolsets.security import OAuth2Auth + +oauth2_auth = OAuth2Auth(token_url="/token", validator=verify_token) + +@app.get("/me") +async def me(user: User = Security(oauth2_auth)): + return user +``` + +### [`OpenIDAuth`](../reference/security.md#fastapi_toolsets.security.OpenIDAuth) + +Reads the `Authorization: Bearer ` header and registers the OpenID Connect +discovery URL in OpenAPI via `OpenIdConnect`. Token validation is fully delegated +to your validator — use any OIDC / JWT library (`authlib`, `python-jose`, `PyJWT`). + +```python +from fastapi_toolsets.security import OpenIDAuth + +async def verify_google_token(token: str, *, audience: str) -> User: + payload = jwt.decode(token, google_public_keys, algorithms=["RS256"], + audience=audience) + return User(email=payload["email"], name=payload["name"]) + +google_auth = OpenIDAuth( + "https://accounts.google.com/.well-known/openid-configuration", + verify_google_token, + audience="my-client-id", +) + +@app.get("/me") +async def me(user: User = Security(google_auth)): + return user +``` + +The discovery URL is used **only for OpenAPI documentation** — no requests are made +to it by this class. You are responsible for fetching and caching the provider's +public keys in your validator. + +Multiple providers work naturally with `MultiAuth`: + +```python +multi = MultiAuth(google_auth, github_auth) + +@app.get("/data") +async def data(user: User = Security(multi)): + return user +``` + +## Typed validator kwargs + +All auth classes forward extra instantiation keyword arguments to the validator. +Arguments can be any type — enums, strings, integers, etc. The validator returns +the authenticated identity, which FastAPI injects directly into the route handler. + +```python +async def verify_token(token: str, *, role: Role, permission: str) -> User: + user = await decode_token(token) + if user.role != role or permission not in user.permissions: + raise UnauthorizedError() + return user + +bearer = BearerTokenAuth(verify_token, role=Role.ADMIN, permission="billing:read") +``` + +Each auth instance is self-contained — create a separate instance per distinct +requirement instead of passing requirements through `Security(scopes=[...])`. + +### Using `.require()` inline + +If declaring a new top-level variable per role feels verbose, use `.require()` to +create a configured clone directly in the route decorator. The original instance +is not mutated: + +```python +bearer = BearerTokenAuth(verify_token) + +@app.get("/admin/stats") +async def admin_stats(user: User = Security(bearer.require(role=Role.ADMIN))): + return {"message": f"Hello admin {user.name}"} + +@app.get("/profile") +async def profile(user: User = Security(bearer.require(role=Role.USER))): + return {"id": user.id, "name": user.name} +``` + +`.require()` kwargs are merged over existing ones — new values win on conflict. +The `prefix` (for `BearerTokenAuth`) and cookie name (for `CookieAuth`) are +always preserved. + +`.require()` instances work transparently inside `MultiAuth`: + +```python +multi = MultiAuth( + user_bearer.require(role=Role.USER), + org_bearer.require(role=Role.ADMIN), +) +``` + +## MultiAuth + +[`MultiAuth`](../reference/security.md#fastapi_toolsets.security.MultiAuth) combines +multiple auth sources into a single callable. Sources are tried in order; the +first one that finds a credential wins. + +```python +from fastapi_toolsets.security import MultiAuth + +multi = MultiAuth(user_bearer, org_bearer, cookie_auth) + +@app.get("/data") +async def data_route(user = Security(multi)): + return user +``` + +### Using `.require()` on MultiAuth + +`MultiAuth` also supports `.require()`, which propagates the kwargs to every +source that implements it. Sources that do not (e.g. custom `AuthSource` +subclasses) are passed through unchanged: + +```python +multi = MultiAuth(bearer, cookie) + +@app.get("/admin") +async def admin(user: User = Security(multi.require(role=Role.ADMIN))): + return user +``` + +This is equivalent to calling `.require()` on each source individually: + +```python +# These two are identical +multi.require(role=Role.ADMIN) + +MultiAuth( + bearer.require(role=Role.ADMIN), + cookie.require(role=Role.ADMIN), +) +``` + +### Prefix-based dispatch + +Because `extract()` is pure string matching (no I/O), prefix-based source +selection is essentially free. Only the matching source's validator (which may +involve DB or network I/O) is ever called: + +```python +user_bearer = BearerTokenAuth(verify_user, prefix="user_") +org_bearer = BearerTokenAuth(verify_org, prefix="org_") + +multi = MultiAuth(user_bearer, org_bearer) + +# "Bearer user_alice" → only verify_user runs, receives "user_alice" +# "Bearer org_acme" → only verify_org runs, receives "org_acme" +``` + +Tokens are stored and compared **with their prefix** — use `generate_token()` on +each source to issue correctly-prefixed tokens: + +```python +user_token = user_bearer.generate_token() # "user_..." +org_token = org_bearer.generate_token() # "org_..." +``` + +--- + +[:material-api: API Reference](../reference/security.md) diff --git a/docs/reference/security.md b/docs/reference/security.md new file mode 100644 index 0000000..f38235d --- /dev/null +++ b/docs/reference/security.md @@ -0,0 +1,28 @@ +# `security` + +Here's the reference for the authentication helpers provided by the `security` module. + +You can import them directly from `fastapi_toolsets.security`: + +```python +from fastapi_toolsets.security import ( + AuthSource, + BearerTokenAuth, + CookieAuth, + OAuth2Auth, + OpenIDAuth, + MultiAuth, +) +``` + +## ::: fastapi_toolsets.security.AuthSource + +## ::: fastapi_toolsets.security.BearerTokenAuth + +## ::: fastapi_toolsets.security.CookieAuth + +## ::: fastapi_toolsets.security.OAuth2Auth + +## ::: fastapi_toolsets.security.OpenIDAuth + +## ::: fastapi_toolsets.security.MultiAuth diff --git a/src/fastapi_toolsets/security/__init__.py b/src/fastapi_toolsets/security/__init__.py new file mode 100644 index 0000000..3d5505d --- /dev/null +++ b/src/fastapi_toolsets/security/__init__.py @@ -0,0 +1,15 @@ +"""Authentication helpers for FastAPI using Security().""" + +from .abc import AuthSource +from .oauth import decode_oauth_state, encode_oauth_state +from .sources import APIKeyHeaderAuth, BearerTokenAuth, CookieAuth, MultiAuth + +__all__ = [ + "APIKeyHeaderAuth", + "AuthSource", + "BearerTokenAuth", + "CookieAuth", + "MultiAuth", + "decode_oauth_state", + "encode_oauth_state", +] diff --git a/src/fastapi_toolsets/security/abc.py b/src/fastapi_toolsets/security/abc.py new file mode 100644 index 0000000..6640bf8 --- /dev/null +++ b/src/fastapi_toolsets/security/abc.py @@ -0,0 +1,51 @@ +"""Abstract base class for authentication sources.""" + +import inspect +from abc import ABC, abstractmethod +from typing import Any, Callable + +from fastapi import Request +from fastapi.security import SecurityScopes + +from fastapi_toolsets.exceptions import UnauthorizedError + + +async def _call_validator( + validator: Callable[..., Any], *args: Any, **kwargs: Any +) -> Any: + """Call *validator* with *args* and *kwargs*, awaiting it if it is a coroutine function.""" + if inspect.iscoroutinefunction(validator): + return await validator(*args, **kwargs) + return validator(*args, **kwargs) + + +class AuthSource(ABC): + """Abstract base class for authentication sources.""" + + def __init__(self) -> None: + """Set up the default FastAPI dependency signature.""" + source = self + + async def _call( + request: Request, + security_scopes: SecurityScopes, # noqa: ARG001 + ) -> Any: + credential = await source.extract(request) + if credential is None: + raise UnauthorizedError() + return await source.authenticate(credential) + + self._call_fn: Callable[..., Any] = _call + self.__signature__ = inspect.signature(_call) + + @abstractmethod + async def extract(self, request: Request) -> str | None: + """Extract the raw credential from the request without validating.""" + + @abstractmethod + async def authenticate(self, credential: str) -> Any: + """Validate a credential and return the authenticated identity.""" + + async def __call__(self, **kwargs: Any) -> Any: + """FastAPI dependency dispatch.""" + return await self._call_fn(**kwargs) diff --git a/src/fastapi_toolsets/security/oauth.py b/src/fastapi_toolsets/security/oauth.py new file mode 100644 index 0000000..4e95281 --- /dev/null +++ b/src/fastapi_toolsets/security/oauth.py @@ -0,0 +1,24 @@ +"""OAuth 2.0 / OIDC helper utilities.""" + +import base64 + + +def encode_oauth_state(url: str) -> str: + """Base64url-encode a URL to embed as an OAuth ``state`` parameter.""" + return base64.urlsafe_b64encode(url.encode()).decode() + + +def decode_oauth_state(state: str | None, *, fallback: str) -> str: + """Decode a base64url OAuth ``state`` parameter. + + Handles missing padding (some providers strip ``=``). + Returns *fallback* if *state* is absent, the literal string ``"null"``, + or cannot be decoded. + """ + if not state or state == "null": + return fallback + try: + padded = state + "=" * (4 - len(state) % 4) + return base64.urlsafe_b64decode(padded).decode() + except Exception: + return fallback diff --git a/src/fastapi_toolsets/security/sources/__init__.py b/src/fastapi_toolsets/security/sources/__init__.py new file mode 100644 index 0000000..8f90c54 --- /dev/null +++ b/src/fastapi_toolsets/security/sources/__init__.py @@ -0,0 +1,8 @@ +"""Built-in authentication source implementations.""" + +from .header import APIKeyHeaderAuth +from .bearer import BearerTokenAuth +from .cookie import CookieAuth +from .multi import MultiAuth + +__all__ = ["APIKeyHeaderAuth", "BearerTokenAuth", "CookieAuth", "MultiAuth"] diff --git a/src/fastapi_toolsets/security/sources/bearer.py b/src/fastapi_toolsets/security/sources/bearer.py new file mode 100644 index 0000000..3e49a33 --- /dev/null +++ b/src/fastapi_toolsets/security/sources/bearer.py @@ -0,0 +1,122 @@ +"""Bearer token authentication source.""" + +import inspect +import secrets +from typing import Annotated, Any, Callable + +from fastapi import Depends +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer, SecurityScopes + +from fastapi_toolsets.exceptions import UnauthorizedError + +from ..abc import AuthSource, _call_validator + + +class BearerTokenAuth(AuthSource): + """Bearer token authentication source. + + Wraps :class:`fastapi.security.HTTPBearer` for OpenAPI documentation. + The validator is called as ``await validator(credential, **kwargs)`` + where ``kwargs`` are the extra keyword arguments provided at instantiation. + + Args: + validator: Sync or async callable that receives the credential and any + extra keyword arguments, and returns the authenticated identity + (e.g. a ``User`` model). Should raise + :class:`~fastapi_toolsets.exceptions.UnauthorizedError` on failure. + prefix: Optional token prefix (e.g. ``"user_"``). If set, only tokens + whose value starts with this prefix are matched. The prefix is + **kept** in the value passed to the validator — store and compare + tokens with their prefix included. Use :meth:`generate_token` to + create correctly-prefixed tokens. This enables multiple + ``BearerTokenAuth`` instances in the same app (e.g. ``"user_"`` + for user tokens, ``"org_"`` for org tokens). + **kwargs: Extra keyword arguments forwarded to the validator on every + call (e.g. ``role=Role.ADMIN``). + """ + + def __init__( + self, + validator: Callable[..., Any], + *, + prefix: str | None = None, + **kwargs: Any, + ) -> None: + self._validator = validator + self._prefix = prefix + self._kwargs = kwargs + self._scheme = HTTPBearer(auto_error=False) + + _scheme = self._scheme + _validator = validator + _kwargs = kwargs + _prefix = prefix + + async def _call( + security_scopes: SecurityScopes, # noqa: ARG001 + credentials: Annotated[ + HTTPAuthorizationCredentials | None, Depends(_scheme) + ] = None, + ) -> Any: + if credentials is None: + raise UnauthorizedError() + token = credentials.credentials + if _prefix is not None and not token.startswith(_prefix): + raise UnauthorizedError() + return await _call_validator(_validator, token, **_kwargs) + + self._call_fn = _call + self.__signature__ = inspect.signature(_call) + + async def extract(self, request: Any) -> str | None: + """Extract the raw credential from the request without validating. + + Returns ``None`` if no ``Authorization: Bearer`` header is present, + the token is empty, or the token does not match the configured prefix. + The prefix is included in the returned value. + """ + auth = request.headers.get("Authorization", "") + if not auth.startswith("Bearer "): + return None + token = auth[7:] + if not token: + return None + if self._prefix is not None and not token.startswith(self._prefix): + return None + return token + + async def authenticate(self, credential: str) -> Any: + """Validate a credential and return the identity. + + Calls ``await validator(credential, **kwargs)`` where ``kwargs`` are + the extra keyword arguments provided at instantiation. + """ + return await _call_validator(self._validator, credential, **self._kwargs) + + def require(self, **kwargs: Any) -> "BearerTokenAuth": + """Return a new instance with additional (or overriding) validator kwargs.""" + return BearerTokenAuth( + self._validator, + prefix=self._prefix, + **{**self._kwargs, **kwargs}, + ) + + def generate_token(self, nbytes: int = 32) -> str: + """Generate a secure random token for this auth source. + + Returns a URL-safe random token. If a prefix is configured it is + prepended — the returned value is what you store in your database + and return to the client as-is. + + Args: + nbytes: Number of random bytes before base64 encoding. The + resulting string is ``ceil(nbytes * 4 / 3)`` characters + (43 chars for the default 32 bytes). Defaults to 32. + + Returns: + A ready-to-use token string (e.g. ``"user_Xk3..."``). + """ + token = secrets.token_urlsafe(nbytes) + if self._prefix is not None: + return f"{self._prefix}{token}" + return token diff --git a/src/fastapi_toolsets/security/sources/cookie.py b/src/fastapi_toolsets/security/sources/cookie.py new file mode 100644 index 0000000..89eb2bf --- /dev/null +++ b/src/fastapi_toolsets/security/sources/cookie.py @@ -0,0 +1,142 @@ +"""Cookie-based authentication source.""" + +import base64 +import hashlib +import hmac +import inspect +import json +import time +from typing import Annotated, Any, Callable + +from fastapi import Depends, Request, Response +from fastapi.security import APIKeyCookie, SecurityScopes + +from fastapi_toolsets.exceptions import UnauthorizedError + +from ..abc import AuthSource, _call_validator + + +class CookieAuth(AuthSource): + """Cookie-based authentication source. + + Wraps :class:`fastapi.security.APIKeyCookie` for OpenAPI documentation. + Optionally signs the cookie with HMAC-SHA256 to provide stateless, tamper- + proof sessions without any database entry. + + Args: + name: Cookie name. + validator: Sync or async callable that receives the cookie value + (plain, after signature verification when ``secret_key`` is set) + and any extra keyword arguments, and returns the authenticated + identity. + secret_key: When provided, the cookie is HMAC-SHA256 signed. + :meth:`set_cookie` embeds an expiry and signs the payload; + :meth:`extract` verifies the signature and expiry before handing + the plain value to the validator. When ``None`` (default), the raw + cookie value is passed to the validator as-is. + ttl: Cookie lifetime in seconds (default 24 h). Only used when + ``secret_key`` is set. + **kwargs: Extra keyword arguments forwarded to the validator on every + call (e.g. ``role=Role.ADMIN``). + """ + + def __init__( + self, + name: str, + validator: Callable[..., Any], + *, + secret_key: str | None = None, + ttl: int = 86400, + **kwargs: Any, + ) -> None: + self._name = name + self._validator = validator + self._secret_key = secret_key + self._ttl = ttl + self._kwargs = kwargs + self._scheme = APIKeyCookie(name=name, auto_error=False) + + _scheme = self._scheme + _self = self + _kwargs = kwargs + + async def _call( + security_scopes: SecurityScopes, # noqa: ARG001 + value: Annotated[str | None, Depends(_scheme)] = None, + ) -> Any: + if value is None: + raise UnauthorizedError() + plain = _self._verify(value) + return await _call_validator(_self._validator, plain, **_kwargs) + + self._call_fn = _call + self.__signature__ = inspect.signature(_call) + + def _hmac(self, data: str) -> str: + assert self._secret_key is not None + return hmac.new( + self._secret_key.encode(), data.encode(), hashlib.sha256 + ).hexdigest() + + def _sign(self, value: str) -> str: + data = base64.urlsafe_b64encode( + json.dumps({"v": value, "exp": int(time.time()) + self._ttl}).encode() + ).decode() + return f"{data}.{self._hmac(data)}" + + def _verify(self, cookie_value: str) -> str: + """Return the plain value, verifying HMAC + expiry when signed.""" + if not self._secret_key: + return cookie_value + + try: + data, sig = cookie_value.rsplit(".", 1) + except ValueError: + raise UnauthorizedError() + + if not hmac.compare_digest(self._hmac(data), sig): + raise UnauthorizedError() + + try: + payload = json.loads(base64.urlsafe_b64decode(data)) + value: str = payload["v"] + exp: int = payload["exp"] + except Exception: + raise UnauthorizedError() + + if exp < int(time.time()): + raise UnauthorizedError() + + return value + + async def extract(self, request: Request) -> str | None: + return request.cookies.get(self._name) + + async def authenticate(self, credential: str) -> Any: + plain = self._verify(credential) + return await _call_validator(self._validator, plain, **self._kwargs) + + def require(self, **kwargs: Any) -> "CookieAuth": + """Return a new instance with additional (or overriding) validator kwargs.""" + return CookieAuth( + self._name, + self._validator, + secret_key=self._secret_key, + ttl=self._ttl, + **{**self._kwargs, **kwargs}, + ) + + def set_cookie(self, response: Response, value: str) -> None: + """Attach the cookie to *response*, signing it when ``secret_key`` is set.""" + cookie_value = self._sign(value) if self._secret_key else value + response.set_cookie( + self._name, + cookie_value, + httponly=True, + samesite="lax", + max_age=self._ttl, + ) + + def delete_cookie(self, response: Response) -> None: + """Clear the session cookie (logout).""" + response.delete_cookie(self._name, httponly=True, samesite="lax") diff --git a/src/fastapi_toolsets/security/sources/header.py b/src/fastapi_toolsets/security/sources/header.py new file mode 100644 index 0000000..067d5e2 --- /dev/null +++ b/src/fastapi_toolsets/security/sources/header.py @@ -0,0 +1,71 @@ +"""API key header authentication source.""" + +import inspect +from typing import Annotated, Any, Callable + +from fastapi import Depends, Request +from fastapi.security import APIKeyHeader, SecurityScopes + +from fastapi_toolsets.exceptions import UnauthorizedError + +from ..abc import AuthSource, _call_validator + + +class APIKeyHeaderAuth(AuthSource): + """API key header authentication source. + + Wraps :class:`fastapi.security.APIKeyHeader` for OpenAPI documentation. + The validator is called as ``await validator(api_key, **kwargs)`` + where ``kwargs`` are the extra keyword arguments provided at instantiation. + + Args: + name: HTTP header name that carries the API key (e.g. ``"X-API-Key"``). + validator: Sync or async callable that receives the API key and any + extra keyword arguments, and returns the authenticated identity. + Should raise :class:`~fastapi_toolsets.exceptions.UnauthorizedError` + on failure. + **kwargs: Extra keyword arguments forwarded to the validator on every + call (e.g. ``role=Role.ADMIN``). + """ + + def __init__( + self, + name: str, + validator: Callable[..., Any], + **kwargs: Any, + ) -> None: + self._name = name + self._validator = validator + self._kwargs = kwargs + self._scheme = APIKeyHeader(name=name, auto_error=False) + + _scheme = self._scheme + _validator = validator + _kwargs = kwargs + + async def _call( + security_scopes: SecurityScopes, # noqa: ARG001 + api_key: Annotated[str | None, Depends(_scheme)] = None, + ) -> Any: + if api_key is None: + raise UnauthorizedError() + return await _call_validator(_validator, api_key, **_kwargs) + + self._call_fn = _call + self.__signature__ = inspect.signature(_call) + + async def extract(self, request: Request) -> str | None: + """Extract the API key from the configured header.""" + return request.headers.get(self._name) or None + + async def authenticate(self, credential: str) -> Any: + """Validate a credential and return the identity.""" + return await _call_validator(self._validator, credential, **self._kwargs) + + def require(self, **kwargs: Any) -> "APIKeyHeaderAuth": + """Return a new instance with additional (or overriding) validator kwargs.""" + return APIKeyHeaderAuth( + self._name, + self._validator, + **{**self._kwargs, **kwargs}, + ) diff --git a/src/fastapi_toolsets/security/sources/multi.py b/src/fastapi_toolsets/security/sources/multi.py new file mode 100644 index 0000000..78b6558 --- /dev/null +++ b/src/fastapi_toolsets/security/sources/multi.py @@ -0,0 +1,121 @@ +"""MultiAuth: combine multiple authentication sources into a single callable.""" + +import inspect +from typing import Any, cast + +from fastapi import Request +from fastapi.security import SecurityScopes + +from fastapi_toolsets.exceptions import UnauthorizedError + +from ..abc import AuthSource + + +class MultiAuth: + """Combine multiple authentication sources into a single callable. + + Sources are tried in order; the first one whose + :meth:`~AuthSource.extract` returns a non-``None`` credential wins. + Its :meth:`~AuthSource.authenticate` is called and the result returned. + + If a credential is found but the validator raises, the exception propagates + immediately — the remaining sources are **not** tried. This prevents + silent fallthrough on invalid credentials. + + If no source provides a credential, + :class:`~fastapi_toolsets.exceptions.UnauthorizedError` is raised. + + The :meth:`~AuthSource.extract` method of each source performs only + string matching (no I/O), so prefix-based dispatch is essentially free. + + Any :class:`~AuthSource` subclass — including user-defined ones — can be + passed as a source. + + Args: + *sources: Auth source instances to try in order. + + Example:: + + user_bearer = BearerTokenAuth(verify_user, prefix="user_") + org_bearer = BearerTokenAuth(verify_org, prefix="org_") + cookie = CookieAuth("session", verify_session) + + multi = MultiAuth(user_bearer, org_bearer, cookie) + + @app.get("/data") + async def data_route(user = Security(multi)): + return user + + # Apply a shared requirement to all sources at once + @app.get("/admin") + async def admin_route(user = Security(multi.require(role=Role.ADMIN))): + return user + """ + + def __init__(self, *sources: AuthSource) -> None: + self._sources = sources + + _sources = sources + + async def _call( + request: Request, + security_scopes: SecurityScopes, # noqa: ARG001 + **kwargs: Any, # noqa: ARG001 — absorbs scheme values injected by FastAPI + ) -> Any: + for source in _sources: + credential = await source.extract(request) + if credential is not None: + return await source.authenticate(credential) + raise UnauthorizedError() + + self._call_fn = _call + + # Build a merged signature that includes the security-scheme Depends() + # parameters from every source so FastAPI registers them in OpenAPI docs. + seen: set[str] = {"request", "security_scopes"} + merged: list[inspect.Parameter] = [ + inspect.Parameter( + "request", + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=Request, + ), + inspect.Parameter( + "security_scopes", + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=SecurityScopes, + ), + ] + for i, source in enumerate(sources): + for name, param in inspect.signature(source).parameters.items(): + if name in seen: + continue + merged.append(param.replace(name=f"_s{i}_{name}")) + seen.add(name) + self.__signature__ = inspect.Signature(merged, return_annotation=Any) + + async def __call__(self, **kwargs: Any) -> Any: + return await self._call_fn(**kwargs) + + def require(self, **kwargs: Any) -> "MultiAuth": + """Return a new :class:`MultiAuth` with kwargs forwarded to each source. + + Calls ``.require(**kwargs)`` on every source that supports it. Sources + that do not implement ``.require()`` (e.g. custom :class:`~AuthSource` + subclasses) are passed through unchanged. + + New kwargs are merged over each source's existing kwargs — new values + win on conflict:: + + multi = MultiAuth(bearer, cookie) + + @app.get("/admin") + async def admin(user = Security(multi.require(role=Role.ADMIN))): + return user + """ + new_sources = tuple( + cast(Any, source).require(**kwargs) + if hasattr(source, "require") + else source + for source in self._sources + ) + return MultiAuth(*new_sources) diff --git a/tests/test_security.py b/tests/test_security.py new file mode 100644 index 0000000..2e5d600 --- /dev/null +++ b/tests/test_security.py @@ -0,0 +1,964 @@ +"""Tests for fastapi_toolsets.security.""" + +import pytest +from fastapi import FastAPI, Security +from fastapi.testclient import TestClient + +from fastapi_toolsets.exceptions import UnauthorizedError, init_exceptions_handlers +from fastapi_toolsets.security import ( + APIKeyHeaderAuth, + AuthSource, + BearerTokenAuth, + CookieAuth, + MultiAuth, +) + + +def _app(*routes_setup_fns): + """Build a minimal FastAPI test app with exception handlers.""" + app = FastAPI() + init_exceptions_handlers(app) + for fn in routes_setup_fns: + fn(app) + return app + + +VALID_TOKEN = "secret" +VALID_COOKIE = "session123" + + +async def simple_validator(credential: str) -> dict: + if credential != VALID_TOKEN: + raise UnauthorizedError() + return {"user": "alice"} + + +async def role_validator(credential: str, *, role: str) -> dict: + if credential != VALID_TOKEN: + raise UnauthorizedError() + return {"user": "alice", "role": role} + + +async def cookie_validator(value: str) -> dict: + if value != VALID_COOKIE: + raise UnauthorizedError() + return {"session": value} + + +class TestBearerTokenAuth: + def test_valid_token_returns_identity(self): + bearer = BearerTokenAuth(simple_validator) + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(bearer)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me", headers={"Authorization": f"Bearer {VALID_TOKEN}"}) + assert response.status_code == 200 + assert response.json() == {"user": "alice"} + + def test_missing_header_returns_401(self): + bearer = BearerTokenAuth(simple_validator) + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(bearer)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me") + assert response.status_code == 401 + + def test_invalid_token_returns_401(self): + bearer = BearerTokenAuth(simple_validator) + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(bearer)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me", headers={"Authorization": "Bearer wrong"}) + assert response.status_code == 401 + + def test_kwargs_forwarded_to_validator(self): + bearer = BearerTokenAuth(role_validator, role="admin") + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(bearer)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me", headers={"Authorization": f"Bearer {VALID_TOKEN}"}) + assert response.status_code == 200 + assert response.json() == {"user": "alice", "role": "admin"} + + def test_prefix_matching_passes_full_token(self): + """Token with matching prefix: full token (with prefix) is passed to validator.""" + received: list[str] = [] + + async def capturing_validator(credential: str) -> dict: + received.append(credential) + return {"user": "alice"} + + bearer = BearerTokenAuth(capturing_validator, prefix="user_") + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(bearer)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me", headers={"Authorization": "Bearer user_abc123"}) + assert response.status_code == 200 + # Prefix is kept — validator receives the full token as stored in DB + assert received == ["user_abc123"] + + def test_prefix_mismatch_returns_401(self): + bearer = BearerTokenAuth(simple_validator, prefix="user_") + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(bearer)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me", headers={"Authorization": "Bearer org_abc123"}) + assert response.status_code == 401 + + # --- extract() --- + + @pytest.mark.anyio + async def test_extract_no_header(self): + from starlette.requests import Request + + bearer = BearerTokenAuth(simple_validator) + scope = {"type": "http", "method": "GET", "path": "/", "headers": []} + request = Request(scope) + assert await bearer.extract(request) is None + + @pytest.mark.anyio + async def test_extract_empty_token(self): + from starlette.requests import Request + + bearer = BearerTokenAuth(simple_validator) + scope = { + "type": "http", + "method": "GET", + "path": "/", + "headers": [(b"authorization", b"Bearer ")], + } + request = Request(scope) + assert await bearer.extract(request) is None + + @pytest.mark.anyio + async def test_extract_no_prefix(self): + from starlette.requests import Request + + bearer = BearerTokenAuth(simple_validator) + scope = { + "type": "http", + "method": "GET", + "path": "/", + "headers": [(b"authorization", b"Bearer mytoken")], + } + request = Request(scope) + assert await bearer.extract(request) == "mytoken" + + @pytest.mark.anyio + async def test_extract_prefix_match(self): + from starlette.requests import Request + + bearer = BearerTokenAuth(simple_validator, prefix="user_") + scope = { + "type": "http", + "method": "GET", + "path": "/", + "headers": [(b"authorization", b"Bearer user_abc")], + } + request = Request(scope) + assert await bearer.extract(request) == "user_abc" + + @pytest.mark.anyio + async def test_extract_prefix_no_match(self): + from starlette.requests import Request + + bearer = BearerTokenAuth(simple_validator, prefix="user_") + scope = { + "type": "http", + "method": "GET", + "path": "/", + "headers": [(b"authorization", b"Bearer org_abc")], + } + request = Request(scope) + assert await bearer.extract(request) is None + + # --- generate_token() --- + + def test_generate_token_no_prefix(self): + bearer = BearerTokenAuth(simple_validator) + token = bearer.generate_token() + assert isinstance(token, str) + assert len(token) > 0 + + def test_generate_token_with_prefix(self): + bearer = BearerTokenAuth(simple_validator, prefix="user_") + token = bearer.generate_token() + assert token.startswith("user_") + + def test_generate_token_uniqueness(self): + bearer = BearerTokenAuth(simple_validator) + assert bearer.generate_token() != bearer.generate_token() + + def test_generate_token_is_valid_credential(self): + """A generated token (with prefix) is accepted by the same auth source.""" + stored: list[str] = [] + + async def storing_validator(credential: str) -> dict: + stored.append(credential) + return {"token": credential} + + bearer = BearerTokenAuth(storing_validator, prefix="user_") + token = bearer.generate_token() + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(bearer)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me", headers={"Authorization": f"Bearer {token}"}) + assert response.status_code == 200 + assert stored == [token] + + +class TestCookieAuth: + def test_valid_cookie_returns_identity(self): + cookie_auth = CookieAuth("session", cookie_validator) + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(cookie_auth)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me", cookies={"session": VALID_COOKIE}) + assert response.status_code == 200 + assert response.json() == {"session": VALID_COOKIE} + + def test_missing_cookie_returns_401(self): + cookie_auth = CookieAuth("session", cookie_validator) + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(cookie_auth)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me") + assert response.status_code == 401 + + def test_invalid_cookie_returns_401(self): + cookie_auth = CookieAuth("session", cookie_validator) + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(cookie_auth)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me", cookies={"session": "wrong"}) + assert response.status_code == 401 + + def test_kwargs_forwarded_to_validator(self): + async def session_validator(value: str, *, scope: str) -> dict: + if value != VALID_COOKIE: + raise UnauthorizedError() + return {"session": value, "scope": scope} + + cookie_auth = CookieAuth("session", session_validator, scope="read") + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(cookie_auth)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me", cookies={"session": VALID_COOKIE}) + assert response.status_code == 200 + assert response.json() == {"session": VALID_COOKIE, "scope": "read"} + + @pytest.mark.anyio + async def test_extract_no_cookie(self): + from starlette.requests import Request + + auth = CookieAuth("session", cookie_validator) + scope = {"type": "http", "method": "GET", "path": "/", "headers": []} + request = Request(scope) + assert await auth.extract(request) is None + + @pytest.mark.anyio + async def test_extract_cookie_present(self): + from starlette.requests import Request + + auth = CookieAuth("session", cookie_validator) + scope = { + "type": "http", + "method": "GET", + "path": "/", + "headers": [(b"cookie", b"session=abc")], + } + request = Request(scope) + assert await auth.extract(request) == "abc" + + +class TestAPIKeyHeaderAuth: + def test_valid_key_returns_identity(self): + auth = APIKeyHeaderAuth("X-API-Key", simple_validator) + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(auth)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me", headers={"X-API-Key": VALID_TOKEN}) + assert response.status_code == 200 + assert response.json() == {"user": "alice"} + + def test_missing_header_returns_401(self): + auth = APIKeyHeaderAuth("X-API-Key", simple_validator) + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(auth)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me") + assert response.status_code == 401 + + def test_invalid_key_returns_401(self): + auth = APIKeyHeaderAuth("X-API-Key", simple_validator) + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(auth)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me", headers={"X-API-Key": "wrong"}) + assert response.status_code == 401 + + def test_kwargs_forwarded_to_validator(self): + auth = APIKeyHeaderAuth("X-API-Key", role_validator, role="admin") + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(auth)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me", headers={"X-API-Key": VALID_TOKEN}) + assert response.status_code == 200 + assert response.json() == {"user": "alice", "role": "admin"} + + def test_require_forwards_kwargs(self): + auth = APIKeyHeaderAuth("X-API-Key", role_validator) + + def setup(app: FastAPI): + @app.get("/admin") + async def admin(user=Security(auth.require(role="admin"))): + return user + + client = TestClient(_app(setup)) + response = client.get("/admin", headers={"X-API-Key": VALID_TOKEN}) + assert response.status_code == 200 + assert response.json() == {"user": "alice", "role": "admin"} + + def test_require_preserves_name(self): + auth = APIKeyHeaderAuth("X-API-Key", simple_validator) + derived = auth.require(role="admin") + assert derived._name == "X-API-Key" + + def test_require_does_not_mutate_original(self): + auth = APIKeyHeaderAuth("X-API-Key", role_validator, role="user") + auth.require(role="admin") + assert auth._kwargs == {"role": "user"} + + def test_in_multi_auth(self): + """APIKeyHeaderAuth.authenticate() is exercised inside MultiAuth.""" + bearer = BearerTokenAuth(simple_validator) + api_key = APIKeyHeaderAuth("X-API-Key", simple_validator) + multi = MultiAuth(bearer, api_key) + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(multi)): + return user + + client = TestClient(_app(setup)) + # No bearer → falls through to API key header + response = client.get("/me", headers={"X-API-Key": VALID_TOKEN}) + assert response.status_code == 200 + assert response.json() == {"user": "alice"} + + def test_is_auth_source(self): + auth = APIKeyHeaderAuth("X-API-Key", simple_validator) + assert isinstance(auth, AuthSource) + + @pytest.mark.anyio + async def test_extract_no_header(self): + from starlette.requests import Request + + auth = APIKeyHeaderAuth("X-API-Key", simple_validator) + scope = {"type": "http", "method": "GET", "path": "/", "headers": []} + request = Request(scope) + assert await auth.extract(request) is None + + @pytest.mark.anyio + async def test_extract_empty_header(self): + from starlette.requests import Request + + auth = APIKeyHeaderAuth("X-API-Key", simple_validator) + scope = { + "type": "http", + "method": "GET", + "path": "/", + "headers": [(b"x-api-key", b"")], + } + request = Request(scope) + assert await auth.extract(request) is None + + @pytest.mark.anyio + async def test_extract_key_present(self): + from starlette.requests import Request + + auth = APIKeyHeaderAuth("X-API-Key", simple_validator) + scope = { + "type": "http", + "method": "GET", + "path": "/", + "headers": [(b"x-api-key", b"mykey")], + } + request = Request(scope) + assert await auth.extract(request) == "mykey" + + +class TestMultiAuth: + def test_first_source_matches(self): + bearer = BearerTokenAuth(simple_validator) + cookie = CookieAuth("session", cookie_validator) + multi = MultiAuth(bearer, cookie) + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(multi)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me", headers={"Authorization": f"Bearer {VALID_TOKEN}"}) + assert response.status_code == 200 + assert response.json() == {"user": "alice"} + + def test_second_source_matches_when_first_absent(self): + bearer = BearerTokenAuth(simple_validator) + cookie = CookieAuth("session", cookie_validator) + multi = MultiAuth(bearer, cookie) + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(multi)): + return user + + client = TestClient(_app(setup)) + # No Authorization header — falls through to cookie + response = client.get("/me", cookies={"session": VALID_COOKIE}) + assert response.status_code == 200 + assert response.json() == {"session": VALID_COOKIE} + + def test_no_source_matches_returns_401(self): + bearer = BearerTokenAuth(simple_validator) + cookie = CookieAuth("session", cookie_validator) + multi = MultiAuth(bearer, cookie) + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(multi)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me") + assert response.status_code == 401 + + def test_invalid_credential_does_not_fallthrough(self): + """If a credential is found but invalid, the next source is NOT tried.""" + second_called: list[bool] = [] + + async def tracking_validator(credential: str) -> dict: + second_called.append(True) + return {"from": "second"} + + bearer = BearerTokenAuth(simple_validator) # raises on wrong token + cookie = CookieAuth("session", tracking_validator) + multi = MultiAuth(bearer, cookie) + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(multi)): + return user + + client = TestClient(_app(setup)) + # Bearer credential present but wrong — should NOT try cookie + response = client.get( + "/me", + headers={"Authorization": "Bearer wrong"}, + cookies={"session": VALID_COOKIE}, + ) + assert response.status_code == 401 + assert second_called == [] # cookie validator was never called + + def test_prefix_routes_to_correct_source(self): + """Prefix-based dispatch: only the matching source's validator is called.""" + user_calls: list[str] = [] + org_calls: list[str] = [] + + async def user_validator(credential: str) -> dict: + user_calls.append(credential) + return {"type": "user", "id": credential} + + async def org_validator(credential: str) -> dict: + org_calls.append(credential) + return {"type": "org", "id": credential} + + user_bearer = BearerTokenAuth(user_validator, prefix="user_") + org_bearer = BearerTokenAuth(org_validator, prefix="org_") + multi = MultiAuth(user_bearer, org_bearer) + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(multi)): + return user + + client = TestClient(_app(setup)) + + response = client.get("/me", headers={"Authorization": "Bearer user_alice"}) + assert response.status_code == 200 + assert response.json() == {"type": "user", "id": "user_alice"} + assert user_calls == ["user_alice"] + assert org_calls == [] + + user_calls.clear() + + response = client.get("/me", headers={"Authorization": "Bearer org_acme"}) + assert response.status_code == 200 + assert response.json() == {"type": "org", "id": "org_acme"} + assert user_calls == [] + assert org_calls == ["org_acme"] + + def test_require_returns_new_multi_auth(self): + from fastapi_toolsets.security.sources import MultiAuth as MultiAuthClass + + bearer = BearerTokenAuth(role_validator) + multi = MultiAuth(bearer) + derived = multi.require(role="admin") + assert isinstance(derived, MultiAuthClass) + assert derived is not multi + + def test_require_forwards_kwargs_to_sources(self): + """multi.require() propagates to all sources that support it.""" + bearer = BearerTokenAuth(role_validator) + multi = MultiAuth(bearer) + + def setup(app: FastAPI): + @app.get("/admin") + async def admin(user=Security(multi.require(role="admin"))): + return user + + client = TestClient(_app(setup)) + response = client.get( + "/admin", headers={"Authorization": f"Bearer {VALID_TOKEN}"} + ) + assert response.status_code == 200 + assert response.json() == {"user": "alice", "role": "admin"} + + def test_require_skips_sources_without_require(self): + """Sources without require() are passed through unchanged.""" + header_auth = _HeaderAuth(secret="s3cr3t") + multi = MultiAuth(header_auth) + derived = multi.require(role="admin") + assert derived._sources[0] is header_auth + + def test_require_does_not_mutate_original(self): + bearer = BearerTokenAuth(role_validator, role="user") + multi = MultiAuth(bearer) + multi.require(role="admin") + assert bearer._kwargs == {"role": "user"} + + def test_require_mixed_sources(self): + """require() applies to sources with require(), skips those without.""" + from typing import cast + + bearer = BearerTokenAuth(role_validator) + header_auth = _HeaderAuth(secret="s3cr3t") + multi = MultiAuth(bearer, header_auth) + derived = multi.require(role="admin") + # bearer got require() applied, header_auth passed through + assert cast(BearerTokenAuth, derived._sources[0])._kwargs == {"role": "admin"} + assert derived._sources[1] is header_auth + + +class TestRequire: + def test_bearer_require_forwards_kwargs(self): + """require() creates a new instance that passes merged kwargs to validator.""" + bearer = BearerTokenAuth(role_validator) + + def setup(app: FastAPI): + @app.get("/admin") + async def admin(user=Security(bearer.require(role="admin"))): + return user + + client = TestClient(_app(setup)) + response = client.get( + "/admin", headers={"Authorization": f"Bearer {VALID_TOKEN}"} + ) + assert response.status_code == 200 + assert response.json() == {"user": "alice", "role": "admin"} + + def test_bearer_require_overrides_existing_kwarg(self): + """require() kwargs override kwargs set at instantiation.""" + bearer = BearerTokenAuth(role_validator, role="user") + + def setup(app: FastAPI): + @app.get("/admin") + async def admin(user=Security(bearer.require(role="admin"))): + return user + + client = TestClient(_app(setup)) + response = client.get( + "/admin", headers={"Authorization": f"Bearer {VALID_TOKEN}"} + ) + assert response.status_code == 200 + assert response.json()["role"] == "admin" + + def test_bearer_require_preserves_prefix(self): + """require() keeps the prefix of the original instance.""" + bearer = BearerTokenAuth(role_validator, prefix="user_") + derived = bearer.require(role="admin") + assert derived._prefix == "user_" + + def test_bearer_require_does_not_mutate_original(self): + """require() returns a new instance — original kwargs are unchanged.""" + bearer = BearerTokenAuth(role_validator, role="user") + bearer.require(role="admin") + assert bearer._kwargs == {"role": "user"} + + def test_cookie_require_forwards_kwargs(self): + async def scoped_validator(value: str, *, scope: str) -> dict: + if value != VALID_COOKIE: + raise UnauthorizedError() + return {"session": value, "scope": scope} + + cookie = CookieAuth("session", scoped_validator) + + def setup(app: FastAPI): + @app.get("/admin") + async def admin(user=Security(cookie.require(scope="admin"))): + return user + + client = TestClient(_app(setup)) + response = client.get("/admin", cookies={"session": VALID_COOKIE}) + assert response.status_code == 200 + assert response.json() == {"session": VALID_COOKIE, "scope": "admin"} + + def test_cookie_require_preserves_name(self): + cookie = CookieAuth("session", cookie_validator) + derived = cookie.require(scope="admin") + assert derived._name == "session" + + def test_bearer_require_in_multi_auth(self): + """require() instances work seamlessly inside MultiAuth.""" + PREFIXED_TOKEN = f"user_{VALID_TOKEN}" + + async def prefixed_role_validator(credential: str, *, role: str) -> dict: + if credential != PREFIXED_TOKEN: + raise UnauthorizedError() + return {"user": "alice", "role": role} + + bearer = BearerTokenAuth(prefixed_role_validator, prefix="user_") + multi = MultiAuth(bearer.require(role="admin")) + + def setup(app: FastAPI): + @app.get("/admin") + async def admin(user=Security(multi)): + return user + + client = TestClient(_app(setup)) + response = client.get( + "/admin", headers={"Authorization": f"Bearer {PREFIXED_TOKEN}"} + ) + assert response.status_code == 200 + assert response.json() == {"user": "alice", "role": "admin"} + + +class TestSyncValidators: + """Sync (non-async) validators — covers the sync path in _call_validator.""" + + def test_bearer_sync_validator(self): + def sync_validator(credential: str) -> dict: + if credential != VALID_TOKEN: + raise UnauthorizedError() + return {"user": "alice"} + + bearer = BearerTokenAuth(sync_validator) + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(bearer)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me", headers={"Authorization": f"Bearer {VALID_TOKEN}"}) + assert response.status_code == 200 + assert response.json() == {"user": "alice"} + + def test_sync_validator_via_authenticate(self): + """authenticate() with sync validator (MultiAuth path).""" + + def sync_validator(credential: str) -> dict: + if credential != VALID_TOKEN: + raise UnauthorizedError() + return {"user": "alice"} + + bearer = BearerTokenAuth(sync_validator) + multi = MultiAuth(bearer) + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(multi)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me", headers={"Authorization": f"Bearer {VALID_TOKEN}"}) + assert response.status_code == 200 + assert response.json() == {"user": "alice"} + + +class TestCookieAuthSigned: + """CookieAuth with HMAC-SHA256 signed cookies (secret_key path).""" + + SECRET = "test-hmac-secret" + + def test_valid_signed_cookie_via_set_cookie(self): + """set_cookie signs the value; the signed cookie is verified on read.""" + from fastapi import Response + + auth = CookieAuth("session", cookie_validator, secret_key=self.SECRET) + + def setup(app: FastAPI): + @app.get("/login") + async def login(response: Response): + auth.set_cookie(response, VALID_COOKIE) + return {"ok": True} + + @app.get("/me") + async def me(user=Security(auth)): + return user + + with TestClient(_app(setup)) as client: + client.get("/login") + response = client.get("/me") + assert response.status_code == 200 + assert response.json() == {"session": VALID_COOKIE} + + def test_tampered_signature_returns_401(self): + """A cookie whose HMAC signature has been modified is rejected.""" + import base64 as _b64 + import json as _json + import time as _time + + auth = CookieAuth("session", cookie_validator, secret_key=self.SECRET) + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(auth)): + return user + + client = TestClient(_app(setup)) + data = _b64.urlsafe_b64encode( + _json.dumps({"v": VALID_COOKIE, "exp": int(_time.time()) + 9999}).encode() + ).decode() + response = client.get("/me", cookies={"session": f"{data}.invalidsig"}) + assert response.status_code == 401 + + def test_expired_signed_cookie_returns_401(self): + """A signed cookie past its expiry timestamp is rejected.""" + import base64 as _b64 + import hashlib as _hashlib + import hmac as _hmac + import json as _json + import time as _time + + auth = CookieAuth("session", cookie_validator, secret_key=self.SECRET) + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(auth)): + return user + + client = TestClient(_app(setup)) + data = _b64.urlsafe_b64encode( + _json.dumps({"v": VALID_COOKIE, "exp": int(_time.time()) - 1}).encode() + ).decode() + sig = _hmac.new( + self.SECRET.encode(), data.encode(), _hashlib.sha256 + ).hexdigest() + response = client.get("/me", cookies={"session": f"{data}.{sig}"}) + assert response.status_code == 401 + + def test_invalid_json_payload_returns_401(self): + """A signed cookie whose payload is not valid JSON is rejected.""" + import base64 as _b64 + import hashlib as _hashlib + import hmac as _hmac + + auth = CookieAuth("session", cookie_validator, secret_key=self.SECRET) + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(auth)): + return user + + client = TestClient(_app(setup)) + data = _b64.urlsafe_b64encode(b"not-valid-json").decode() + sig = _hmac.new( + self.SECRET.encode(), data.encode(), _hashlib.sha256 + ).hexdigest() + response = client.get("/me", cookies={"session": f"{data}.{sig}"}) + assert response.status_code == 401 + + def test_malformed_cookie_no_dot_returns_401(self): + """A signed cookie without the dot separator is rejected.""" + auth = CookieAuth("session", cookie_validator, secret_key=self.SECRET) + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(auth)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me", cookies={"session": "nodothere"}) + assert response.status_code == 401 + + def test_set_cookie_without_secret(self): + """set_cookie without secret_key writes the raw value.""" + from starlette.responses import Response as StarletteResponse + + auth = CookieAuth("session", cookie_validator) + response = StarletteResponse() + auth.set_cookie(response, "rawvalue") + assert "session=rawvalue" in response.headers["set-cookie"] + + def test_delete_cookie(self): + """delete_cookie produces a Set-Cookie header that clears the session.""" + from starlette.responses import Response as StarletteResponse + + auth = CookieAuth("session", cookie_validator) + response = StarletteResponse() + auth.delete_cookie(response) + assert "session" in response.headers["set-cookie"] + + +# Minimal concrete subclass used only in tests below. +class _HeaderAuth(AuthSource): + """Reads a custom X-Token header — no FastAPI security scheme.""" + + def __init__(self, secret: str) -> None: + super().__init__() + self._secret = secret + + async def extract(self, request) -> str | None: + return request.headers.get("X-Token") or None + + async def authenticate(self, credential: str) -> dict: + if credential != self._secret: + raise UnauthorizedError() + return {"token": credential} + + +class TestAuthSource: + def test_cannot_instantiate_abstract_class(self): + with pytest.raises(TypeError): + AuthSource() + + def test_builtin_classes_are_auth_sources(self): + bearer = BearerTokenAuth(simple_validator) + cookie = CookieAuth("session", cookie_validator) + api_key = APIKeyHeaderAuth("X-API-Key", simple_validator) + assert isinstance(bearer, AuthSource) + assert isinstance(cookie, AuthSource) + assert isinstance(api_key, AuthSource) + + def test_custom_source_standalone_valid(self): + """Default __call__ wires extract + authenticate via Request injection.""" + auth = _HeaderAuth(secret="s3cr3t") + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(auth)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me", headers={"X-Token": "s3cr3t"}) + assert response.status_code == 200 + assert response.json() == {"token": "s3cr3t"} + + def test_custom_source_standalone_missing_credential(self): + auth = _HeaderAuth(secret="s3cr3t") + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(auth)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me") # no X-Token header + assert response.status_code == 401 + + def test_custom_source_standalone_invalid_credential(self): + auth = _HeaderAuth(secret="s3cr3t") + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(auth)): + return user + + client = TestClient(_app(setup)) + response = client.get("/me", headers={"X-Token": "wrong"}) + assert response.status_code == 401 + + def test_custom_source_in_multi_auth(self): + """Custom AuthSource works transparently inside MultiAuth.""" + header_auth = _HeaderAuth(secret="s3cr3t") + bearer = BearerTokenAuth(simple_validator) + multi = MultiAuth(bearer, header_auth) + + def setup(app: FastAPI): + @app.get("/me") + async def me(user=Security(multi)): + return user + + client = TestClient(_app(setup)) + + # Bearer matches first + response = client.get("/me", headers={"Authorization": f"Bearer {VALID_TOKEN}"}) + assert response.status_code == 200 + assert response.json() == {"user": "alice"} + + # No bearer → falls through to custom header source + response = client.get("/me", headers={"X-Token": "s3cr3t"}) + assert response.status_code == 200 + assert response.json() == {"token": "s3cr3t"}