mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 14:46:24 +02:00
Compare commits
3 Commits
v2.3.0
...
5a1493266e
| Author | SHA1 | Date | |
|---|---|---|---|
|
5a1493266e
|
|||
|
83c1f98d25
|
|||
|
0bc025b844
|
1
docs/examples/authentication.md
Normal file
1
docs/examples/authentication.md
Normal file
@@ -0,0 +1 @@
|
||||
# Authentication
|
||||
267
docs/module/security.md
Normal file
267
docs/module/security.md
Normal file
@@ -0,0 +1,267 @@
|
||||
# Security
|
||||
|
||||
Composable authentication helpers for FastAPI that use `Security()` for OpenAPI documentation and accept user-provided validator functions with full type flexibility.
|
||||
|
||||
## Overview
|
||||
|
||||
The `security` module provides four auth source classes and a `MultiAuth` factory. Each class wraps a FastAPI security scheme for OpenAPI and accepts a validator function called as:
|
||||
|
||||
```python
|
||||
await validator(credential, **kwargs)
|
||||
```
|
||||
|
||||
where `kwargs` are the extra keyword arguments provided at instantiation (roles, permissions, enums, etc.). The validator returns the authenticated identity (e.g. a `User` model) which becomes the route dependency value.
|
||||
|
||||
```python
|
||||
from fastapi import Security
|
||||
from fastapi_toolsets.security import BearerTokenAuth
|
||||
|
||||
async def verify_token(token: str, *, role: str) -> User:
|
||||
user = await db.get_by_token(token)
|
||||
if not user or user.role != role:
|
||||
raise UnauthorizedError()
|
||||
return user
|
||||
|
||||
bearer_admin = BearerTokenAuth(verify_token, role="admin")
|
||||
|
||||
@app.get("/admin")
|
||||
async def admin_route(user: User = Security(bearer_admin)):
|
||||
return user
|
||||
```
|
||||
|
||||
## Auth sources
|
||||
|
||||
### [`BearerTokenAuth`](../reference/security.md#fastapi_toolsets.security.BearerTokenAuth)
|
||||
|
||||
Reads the `Authorization: Bearer <token>` header. Wraps `HTTPBearer` for OpenAPI.
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.security import BearerTokenAuth
|
||||
|
||||
bearer = BearerTokenAuth(validator=verify_token)
|
||||
|
||||
@app.get("/me")
|
||||
async def me(user: User = Security(bearer)):
|
||||
return user
|
||||
```
|
||||
|
||||
#### Token prefix
|
||||
|
||||
The optional `prefix` parameter restricts a `BearerTokenAuth` instance to tokens
|
||||
that start with a given string. The prefix is **kept** in the value passed to the
|
||||
validator — store and compare tokens with their prefix included.
|
||||
|
||||
This lets you deploy multiple `BearerTokenAuth` instances in the same application
|
||||
and disambiguate them efficiently in `MultiAuth`:
|
||||
|
||||
```python
|
||||
user_bearer = BearerTokenAuth(verify_user, prefix="user_") # matches "Bearer user_..."
|
||||
org_bearer = BearerTokenAuth(verify_org, prefix="org_") # matches "Bearer org_..."
|
||||
```
|
||||
|
||||
Use [`generate_token()`](#token-generation) to create correctly-prefixed tokens.
|
||||
|
||||
#### Token generation
|
||||
|
||||
`BearerTokenAuth.generate_token()` produces a secure random token ready to store
|
||||
in your database and return to the client. If a prefix is configured it is
|
||||
prepended automatically:
|
||||
|
||||
```python
|
||||
bearer = BearerTokenAuth(verify_token, prefix="user_")
|
||||
|
||||
token = bearer.generate_token() # e.g. "user_Xk3mN..."
|
||||
await db.store_token(user_id, token)
|
||||
return {"access_token": token, "token_type": "bearer"}
|
||||
```
|
||||
|
||||
The client sends `Authorization: Bearer user_Xk3mN...` and the validator receives
|
||||
the full token (prefix included) to compare against the stored value.
|
||||
|
||||
### [`CookieAuth`](../reference/security.md#fastapi_toolsets.security.CookieAuth)
|
||||
|
||||
Reads a named cookie. Wraps `APIKeyCookie` for OpenAPI.
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.security import CookieAuth
|
||||
|
||||
cookie_auth = CookieAuth("session", validator=verify_session)
|
||||
|
||||
@app.get("/me")
|
||||
async def me(user: User = Security(cookie_auth)):
|
||||
return user
|
||||
```
|
||||
|
||||
### [`OAuth2Auth`](../reference/security.md#fastapi_toolsets.security.OAuth2Auth)
|
||||
|
||||
Reads the `Authorization: Bearer <token>` header and registers the token endpoint
|
||||
in OpenAPI via `OAuth2PasswordBearer`.
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.security import OAuth2Auth
|
||||
|
||||
oauth2_auth = OAuth2Auth(token_url="/token", validator=verify_token)
|
||||
|
||||
@app.get("/me")
|
||||
async def me(user: User = Security(oauth2_auth)):
|
||||
return user
|
||||
```
|
||||
|
||||
### [`OpenIDAuth`](../reference/security.md#fastapi_toolsets.security.OpenIDAuth)
|
||||
|
||||
Reads the `Authorization: Bearer <token>` header and registers the OpenID Connect
|
||||
discovery URL in OpenAPI via `OpenIdConnect`. Token validation is fully delegated
|
||||
to your validator — use any OIDC / JWT library (`authlib`, `python-jose`, `PyJWT`).
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.security import OpenIDAuth
|
||||
|
||||
async def verify_google_token(token: str, *, audience: str) -> User:
|
||||
payload = jwt.decode(token, google_public_keys, algorithms=["RS256"],
|
||||
audience=audience)
|
||||
return User(email=payload["email"], name=payload["name"])
|
||||
|
||||
google_auth = OpenIDAuth(
|
||||
"https://accounts.google.com/.well-known/openid-configuration",
|
||||
verify_google_token,
|
||||
audience="my-client-id",
|
||||
)
|
||||
|
||||
@app.get("/me")
|
||||
async def me(user: User = Security(google_auth)):
|
||||
return user
|
||||
```
|
||||
|
||||
The discovery URL is used **only for OpenAPI documentation** — no requests are made
|
||||
to it by this class. You are responsible for fetching and caching the provider's
|
||||
public keys in your validator.
|
||||
|
||||
Multiple providers work naturally with `MultiAuth`:
|
||||
|
||||
```python
|
||||
multi = MultiAuth(google_auth, github_auth)
|
||||
|
||||
@app.get("/data")
|
||||
async def data(user: User = Security(multi)):
|
||||
return user
|
||||
```
|
||||
|
||||
## Typed validator kwargs
|
||||
|
||||
All auth classes forward extra instantiation keyword arguments to the validator.
|
||||
Arguments can be any type — enums, strings, integers, etc. The validator returns
|
||||
the authenticated identity, which FastAPI injects directly into the route handler.
|
||||
|
||||
```python
|
||||
async def verify_token(token: str, *, role: Role, permission: str) -> User:
|
||||
user = await decode_token(token)
|
||||
if user.role != role or permission not in user.permissions:
|
||||
raise UnauthorizedError()
|
||||
return user
|
||||
|
||||
bearer = BearerTokenAuth(verify_token, role=Role.ADMIN, permission="billing:read")
|
||||
```
|
||||
|
||||
Each auth instance is self-contained — create a separate instance per distinct
|
||||
requirement instead of passing requirements through `Security(scopes=[...])`.
|
||||
|
||||
### Using `.require()` inline
|
||||
|
||||
If declaring a new top-level variable per role feels verbose, use `.require()` to
|
||||
create a configured clone directly in the route decorator. The original instance
|
||||
is not mutated:
|
||||
|
||||
```python
|
||||
bearer = BearerTokenAuth(verify_token)
|
||||
|
||||
@app.get("/admin/stats")
|
||||
async def admin_stats(user: User = Security(bearer.require(role=Role.ADMIN))):
|
||||
return {"message": f"Hello admin {user.name}"}
|
||||
|
||||
@app.get("/profile")
|
||||
async def profile(user: User = Security(bearer.require(role=Role.USER))):
|
||||
return {"id": user.id, "name": user.name}
|
||||
```
|
||||
|
||||
`.require()` kwargs are merged over existing ones — new values win on conflict.
|
||||
The `prefix` (for `BearerTokenAuth`) and cookie name (for `CookieAuth`) are
|
||||
always preserved.
|
||||
|
||||
`.require()` instances work transparently inside `MultiAuth`:
|
||||
|
||||
```python
|
||||
multi = MultiAuth(
|
||||
user_bearer.require(role=Role.USER),
|
||||
org_bearer.require(role=Role.ADMIN),
|
||||
)
|
||||
```
|
||||
|
||||
## MultiAuth
|
||||
|
||||
[`MultiAuth`](../reference/security.md#fastapi_toolsets.security.MultiAuth) combines
|
||||
multiple auth sources into a single callable. Sources are tried in order; the
|
||||
first one that finds a credential wins.
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.security import MultiAuth
|
||||
|
||||
multi = MultiAuth(user_bearer, org_bearer, cookie_auth)
|
||||
|
||||
@app.get("/data")
|
||||
async def data_route(user = Security(multi)):
|
||||
return user
|
||||
```
|
||||
|
||||
### Using `.require()` on MultiAuth
|
||||
|
||||
`MultiAuth` also supports `.require()`, which propagates the kwargs to every
|
||||
source that implements it. Sources that do not (e.g. custom `AuthSource`
|
||||
subclasses) are passed through unchanged:
|
||||
|
||||
```python
|
||||
multi = MultiAuth(bearer, cookie)
|
||||
|
||||
@app.get("/admin")
|
||||
async def admin(user: User = Security(multi.require(role=Role.ADMIN))):
|
||||
return user
|
||||
```
|
||||
|
||||
This is equivalent to calling `.require()` on each source individually:
|
||||
|
||||
```python
|
||||
# These two are identical
|
||||
multi.require(role=Role.ADMIN)
|
||||
|
||||
MultiAuth(
|
||||
bearer.require(role=Role.ADMIN),
|
||||
cookie.require(role=Role.ADMIN),
|
||||
)
|
||||
```
|
||||
|
||||
### Prefix-based dispatch
|
||||
|
||||
Because `extract()` is pure string matching (no I/O), prefix-based source
|
||||
selection is essentially free. Only the matching source's validator (which may
|
||||
involve DB or network I/O) is ever called:
|
||||
|
||||
```python
|
||||
user_bearer = BearerTokenAuth(verify_user, prefix="user_")
|
||||
org_bearer = BearerTokenAuth(verify_org, prefix="org_")
|
||||
|
||||
multi = MultiAuth(user_bearer, org_bearer)
|
||||
|
||||
# "Bearer user_alice" → only verify_user runs, receives "user_alice"
|
||||
# "Bearer org_acme" → only verify_org runs, receives "org_acme"
|
||||
```
|
||||
|
||||
Tokens are stored and compared **with their prefix** — use `generate_token()` on
|
||||
each source to issue correctly-prefixed tokens:
|
||||
|
||||
```python
|
||||
user_token = user_bearer.generate_token() # "user_..."
|
||||
org_token = org_bearer.generate_token() # "org_..."
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
[:material-api: API Reference](../reference/security.md)
|
||||
28
docs/reference/security.md
Normal file
28
docs/reference/security.md
Normal file
@@ -0,0 +1,28 @@
|
||||
# `security`
|
||||
|
||||
Here's the reference for the authentication helpers provided by the `security` module.
|
||||
|
||||
You can import them directly from `fastapi_toolsets.security`:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.security import (
|
||||
AuthSource,
|
||||
BearerTokenAuth,
|
||||
CookieAuth,
|
||||
OAuth2Auth,
|
||||
OpenIDAuth,
|
||||
MultiAuth,
|
||||
)
|
||||
```
|
||||
|
||||
## ::: fastapi_toolsets.security.AuthSource
|
||||
|
||||
## ::: fastapi_toolsets.security.BearerTokenAuth
|
||||
|
||||
## ::: fastapi_toolsets.security.CookieAuth
|
||||
|
||||
## ::: fastapi_toolsets.security.OAuth2Auth
|
||||
|
||||
## ::: fastapi_toolsets.security.OpenIDAuth
|
||||
|
||||
## ::: fastapi_toolsets.security.MultiAuth
|
||||
0
docs_src/examples/authentication/__init__.py
Normal file
0
docs_src/examples/authentication/__init__.py
Normal file
9
docs_src/examples/authentication/app.py
Normal file
9
docs_src/examples/authentication/app.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||
|
||||
from .routes import router
|
||||
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app=app)
|
||||
app.include_router(router=router)
|
||||
9
docs_src/examples/authentication/crud.py
Normal file
9
docs_src/examples/authentication/crud.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from fastapi_toolsets.crud import CrudFactory
|
||||
|
||||
from .models import OAuthAccount, OAuthProvider, Team, User, UserToken
|
||||
|
||||
TeamCrud = CrudFactory(model=Team)
|
||||
UserCrud = CrudFactory(model=User)
|
||||
UserTokenCrud = CrudFactory(model=UserToken)
|
||||
OAuthProviderCrud = CrudFactory(model=OAuthProvider)
|
||||
OAuthAccountCrud = CrudFactory(model=OAuthAccount)
|
||||
15
docs_src/examples/authentication/db.py
Normal file
15
docs_src/examples/authentication/db.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
from fastapi_toolsets.db import create_db_context, create_db_dependency
|
||||
|
||||
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/postgres"
|
||||
|
||||
engine = create_async_engine(url=DATABASE_URL, future=True)
|
||||
async_session_maker = async_sessionmaker(bind=engine, expire_on_commit=False)
|
||||
|
||||
get_db = create_db_dependency(session_maker=async_session_maker)
|
||||
get_db_context = create_db_context(session_maker=async_session_maker)
|
||||
|
||||
|
||||
SessionDep = Depends(get_db)
|
||||
105
docs_src/examples/authentication/models.py
Normal file
105
docs_src/examples/authentication/models.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import enum
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
DateTime,
|
||||
Enum,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID as PG_UUID
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
|
||||
from fastapi_toolsets.models import TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class Base(DeclarativeBase, UUIDMixin):
|
||||
type_annotation_map = {
|
||||
str: String(),
|
||||
int: Integer(),
|
||||
UUID: PG_UUID(as_uuid=True),
|
||||
datetime: DateTime(timezone=True),
|
||||
}
|
||||
|
||||
|
||||
class UserRole(enum.Enum):
|
||||
admin = "admin"
|
||||
moderator = "moderator"
|
||||
user = "user"
|
||||
|
||||
|
||||
class Team(Base, TimestampMixin):
|
||||
__tablename__ = "teams"
|
||||
|
||||
name: Mapped[str] = mapped_column(String, unique=True, index=True)
|
||||
users: Mapped[list["User"]] = relationship(back_populates="team")
|
||||
|
||||
|
||||
class User(Base, TimestampMixin):
|
||||
__tablename__ = "users"
|
||||
|
||||
username: Mapped[str] = mapped_column(String, unique=True, index=True)
|
||||
email: Mapped[str | None] = mapped_column(
|
||||
String, unique=True, index=True, nullable=True
|
||||
)
|
||||
hashed_password: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
role: Mapped[UserRole] = mapped_column(Enum(UserRole), default=UserRole.user)
|
||||
|
||||
team_id: Mapped[UUID | None] = mapped_column(ForeignKey("teams.id"), nullable=True)
|
||||
team: Mapped["Team | None"] = relationship(back_populates="users")
|
||||
oauth_accounts: Mapped[list["OAuthAccount"]] = relationship(back_populates="user")
|
||||
tokens: Mapped[list["UserToken"]] = relationship(back_populates="user")
|
||||
|
||||
|
||||
class UserToken(Base, TimestampMixin):
|
||||
"""API tokens for a user (multiple allowed)."""
|
||||
|
||||
__tablename__ = "user_tokens"
|
||||
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("users.id"))
|
||||
# Store hashed token value
|
||||
token_hash: Mapped[str] = mapped_column(String, unique=True, index=True)
|
||||
name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
expires_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
user: Mapped["User"] = relationship(back_populates="tokens")
|
||||
|
||||
|
||||
class OAuthProvider(Base, TimestampMixin):
|
||||
"""Configurable OAuth2 / OpenID Connect provider."""
|
||||
|
||||
__tablename__ = "oauth_providers"
|
||||
|
||||
slug: Mapped[str] = mapped_column(String, unique=True, index=True)
|
||||
name: Mapped[str] = mapped_column(String)
|
||||
client_id: Mapped[str] = mapped_column(String)
|
||||
client_secret: Mapped[str] = mapped_column(String)
|
||||
discovery_url: Mapped[str] = mapped_column(String, nullable=False)
|
||||
scopes: Mapped[str] = mapped_column(String, default="openid email profile")
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
accounts: Mapped[list["OAuthAccount"]] = relationship(back_populates="provider")
|
||||
|
||||
|
||||
class OAuthAccount(Base, TimestampMixin):
|
||||
"""OAuth2 / OpenID Connect account linked to a user."""
|
||||
|
||||
__tablename__ = "oauth_accounts"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("provider_id", "subject", name="uq_oauth_provider_subject"),
|
||||
)
|
||||
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("users.id"))
|
||||
provider_id: Mapped[UUID] = mapped_column(ForeignKey("oauth_providers.id"))
|
||||
# OAuth `sub` / OpenID subject identifier
|
||||
subject: Mapped[str] = mapped_column(String)
|
||||
|
||||
user: Mapped["User"] = relationship(back_populates="oauth_accounts")
|
||||
provider: Mapped["OAuthProvider"] = relationship(back_populates="accounts")
|
||||
122
docs_src/examples/authentication/routes.py
Normal file
122
docs_src/examples/authentication/routes.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
import bcrypt
|
||||
from fastapi import APIRouter, Form, HTTPException, Response, Security
|
||||
|
||||
from fastapi_toolsets.dependencies import PathDependency
|
||||
|
||||
from .crud import UserCrud, UserTokenCrud
|
||||
from .db import SessionDep
|
||||
from .models import OAuthProvider, User, UserToken
|
||||
from .schemas import (
|
||||
ApiTokenCreateRequest,
|
||||
ApiTokenResponse,
|
||||
RegisterRequest,
|
||||
UserCreate,
|
||||
UserResponse,
|
||||
)
|
||||
from .security import auth, cookie_auth, create_api_token
|
||||
|
||||
ProviderDep = PathDependency(
|
||||
model=OAuthProvider,
|
||||
field=OAuthProvider.slug,
|
||||
session_dep=SessionDep,
|
||||
param_name="slug",
|
||||
)
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
|
||||
def verify_password(plain: str, hashed: str) -> bool:
|
||||
return bcrypt.checkpw(plain.encode(), hashed.encode())
|
||||
|
||||
|
||||
router = APIRouter(prefix="/auth")
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=201)
|
||||
async def register(body: RegisterRequest, session: SessionDep):
|
||||
existing = await UserCrud.first(
|
||||
session=session, filters=[User.username == body.username]
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="Username already taken")
|
||||
|
||||
user = await UserCrud.create(
|
||||
session=session,
|
||||
obj=UserCreate(
|
||||
username=body.username,
|
||||
email=body.email,
|
||||
hashed_password=hash_password(body.password),
|
||||
),
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
@router.post("/token", status_code=204)
|
||||
async def login(
|
||||
session: SessionDep,
|
||||
response: Response,
|
||||
username: Annotated[str, Form()],
|
||||
password: Annotated[str, Form()],
|
||||
):
|
||||
user = await UserCrud.first(session=session, filters=[User.username == username])
|
||||
|
||||
if (
|
||||
not user
|
||||
or not user.hashed_password
|
||||
or not verify_password(password, user.hashed_password)
|
||||
):
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(status_code=403, detail="Account disabled")
|
||||
|
||||
cookie_auth.set_cookie(response, str(user.id))
|
||||
|
||||
|
||||
@router.post("/logout", status_code=204)
|
||||
async def logout(response: Response):
|
||||
cookie_auth.delete_cookie(response)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def me(user: User = Security(auth)):
|
||||
return user
|
||||
|
||||
|
||||
@router.post("/tokens", response_model=ApiTokenResponse, status_code=201)
|
||||
async def create_token(
|
||||
body: ApiTokenCreateRequest,
|
||||
user: User = Security(auth),
|
||||
):
|
||||
raw, token_row = await create_api_token(
|
||||
user.id, name=body.name, expires_at=body.expires_at
|
||||
)
|
||||
return ApiTokenResponse(
|
||||
id=token_row.id,
|
||||
name=token_row.name,
|
||||
expires_at=token_row.expires_at,
|
||||
created_at=token_row.created_at,
|
||||
token=raw,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/tokens/{token_id}", status_code=204)
|
||||
async def revoke_token(
|
||||
session: SessionDep,
|
||||
token_id: UUID,
|
||||
user: User = Security(auth),
|
||||
):
|
||||
if not await UserTokenCrud.first(
|
||||
session=session,
|
||||
filters=[UserToken.id == token_id, UserToken.user_id == user.id],
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Token not found")
|
||||
await UserTokenCrud.delete(
|
||||
session=session,
|
||||
filters=[UserToken.id == token_id, UserToken.user_id == user.id],
|
||||
)
|
||||
64
docs_src/examples/authentication/schemas.py
Normal file
64
docs_src/examples/authentication/schemas.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import EmailStr
|
||||
|
||||
from fastapi_toolsets.schemas import PydanticBase
|
||||
|
||||
|
||||
class RegisterRequest(PydanticBase):
|
||||
username: str
|
||||
password: str
|
||||
email: EmailStr | None = None
|
||||
|
||||
|
||||
class UserResponse(PydanticBase):
|
||||
id: UUID
|
||||
username: str
|
||||
email: str | None
|
||||
role: str
|
||||
is_active: bool
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ApiTokenCreateRequest(PydanticBase):
|
||||
name: str | None = None
|
||||
expires_at: datetime | None = None
|
||||
|
||||
|
||||
class ApiTokenResponse(PydanticBase):
|
||||
id: UUID
|
||||
name: str | None
|
||||
expires_at: datetime | None
|
||||
created_at: datetime
|
||||
# Only populated on creation
|
||||
token: str | None = None
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class OAuthProviderResponse(PydanticBase):
|
||||
slug: str
|
||||
name: str
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class UserCreate(PydanticBase):
|
||||
username: str
|
||||
email: str | None = None
|
||||
hashed_password: str | None = None
|
||||
|
||||
|
||||
class UserTokenCreate(PydanticBase):
|
||||
user_id: UUID
|
||||
token_hash: str
|
||||
name: str | None = None
|
||||
expires_at: datetime | None = None
|
||||
|
||||
|
||||
class OAuthAccountCreate(PydanticBase):
|
||||
user_id: UUID
|
||||
provider_id: UUID
|
||||
subject: str
|
||||
100
docs_src/examples/authentication/security.py
Normal file
100
docs_src/examples/authentication/security.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import hashlib
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from fastapi_toolsets.exceptions import UnauthorizedError
|
||||
from fastapi_toolsets.security import (
|
||||
APIKeyHeaderAuth,
|
||||
BearerTokenAuth,
|
||||
CookieAuth,
|
||||
MultiAuth,
|
||||
)
|
||||
|
||||
from .crud import UserCrud, UserTokenCrud
|
||||
from .db import get_db_context
|
||||
from .models import User, UserRole, UserToken
|
||||
from .schemas import UserTokenCreate
|
||||
|
||||
SESSION_COOKIE = "session"
|
||||
SECRET_KEY = "123456789"
|
||||
|
||||
|
||||
def _hash_token(token: str) -> str:
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
async def _verify_token(token: str, role: UserRole | None = None) -> User:
|
||||
async with get_db_context() as db:
|
||||
user_token = await UserTokenCrud.first(
|
||||
session=db,
|
||||
filters=[UserToken.token_hash == _hash_token(token)],
|
||||
load_options=[selectinload(UserToken.user)],
|
||||
)
|
||||
|
||||
if user_token is None or not user_token.user.is_active:
|
||||
raise UnauthorizedError()
|
||||
|
||||
if user_token.expires_at and user_token.expires_at < datetime.now(timezone.utc):
|
||||
raise UnauthorizedError()
|
||||
|
||||
user = user_token.user
|
||||
|
||||
if role is not None and user.role != role:
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def _verify_cookie(user_id: str, role: UserRole | None = None) -> User:
|
||||
async with get_db_context() as db:
|
||||
user = await UserCrud.first(
|
||||
session=db,
|
||||
filters=[User.id == UUID(user_id)],
|
||||
)
|
||||
|
||||
if not user or not user.is_active:
|
||||
raise UnauthorizedError()
|
||||
|
||||
if role is not None and user.role != role:
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
bearer_auth = BearerTokenAuth(
|
||||
validator=_verify_token,
|
||||
prefix="ctf_",
|
||||
)
|
||||
header_auth = APIKeyHeaderAuth(
|
||||
name="X-API-Key",
|
||||
validator=_verify_token,
|
||||
)
|
||||
cookie_auth = CookieAuth(
|
||||
name=SESSION_COOKIE,
|
||||
validator=_verify_cookie,
|
||||
secret_key=SECRET_KEY,
|
||||
)
|
||||
auth = MultiAuth(bearer_auth, header_auth, cookie_auth)
|
||||
|
||||
|
||||
async def create_api_token(
|
||||
user_id: UUID,
|
||||
*,
|
||||
name: str | None = None,
|
||||
expires_at: datetime | None = None,
|
||||
) -> tuple[str, UserToken]:
|
||||
raw = bearer_auth.generate_token()
|
||||
async with get_db_context() as db:
|
||||
token_row = await UserTokenCrud.create(
|
||||
session=db,
|
||||
obj=UserTokenCreate(
|
||||
user_id=user_id,
|
||||
token_hash=_hash_token(raw),
|
||||
name=name,
|
||||
expires_at=expires_at,
|
||||
),
|
||||
)
|
||||
return raw, token_row
|
||||
24
src/fastapi_toolsets/security/__init__.py
Normal file
24
src/fastapi_toolsets/security/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Authentication helpers for FastAPI using Security()."""
|
||||
|
||||
from .abc import AuthSource
|
||||
from .oauth import (
|
||||
oauth_build_authorization_redirect,
|
||||
oauth_decode_state,
|
||||
oauth_encode_state,
|
||||
oauth_fetch_userinfo,
|
||||
oauth_resolve_provider_urls,
|
||||
)
|
||||
from .sources import APIKeyHeaderAuth, BearerTokenAuth, CookieAuth, MultiAuth
|
||||
|
||||
__all__ = [
|
||||
"APIKeyHeaderAuth",
|
||||
"AuthSource",
|
||||
"BearerTokenAuth",
|
||||
"CookieAuth",
|
||||
"MultiAuth",
|
||||
"oauth_build_authorization_redirect",
|
||||
"oauth_decode_state",
|
||||
"oauth_encode_state",
|
||||
"oauth_fetch_userinfo",
|
||||
"oauth_resolve_provider_urls",
|
||||
]
|
||||
51
src/fastapi_toolsets/security/abc.py
Normal file
51
src/fastapi_toolsets/security/abc.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""Abstract base class for authentication sources."""
|
||||
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.security import SecurityScopes
|
||||
|
||||
from fastapi_toolsets.exceptions import UnauthorizedError
|
||||
|
||||
|
||||
async def _call_validator(
|
||||
validator: Callable[..., Any], *args: Any, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Call *validator* with *args* and *kwargs*, awaiting it if it is a coroutine function."""
|
||||
if inspect.iscoroutinefunction(validator):
|
||||
return await validator(*args, **kwargs)
|
||||
return validator(*args, **kwargs)
|
||||
|
||||
|
||||
class AuthSource(ABC):
|
||||
"""Abstract base class for authentication sources."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Set up the default FastAPI dependency signature."""
|
||||
source = self
|
||||
|
||||
async def _call(
|
||||
request: Request,
|
||||
security_scopes: SecurityScopes, # noqa: ARG001
|
||||
) -> Any:
|
||||
credential = await source.extract(request)
|
||||
if credential is None:
|
||||
raise UnauthorizedError()
|
||||
return await source.authenticate(credential)
|
||||
|
||||
self._call_fn: Callable[..., Any] = _call
|
||||
self.__signature__ = inspect.signature(_call)
|
||||
|
||||
@abstractmethod
|
||||
async def extract(self, request: Request) -> str | None:
|
||||
"""Extract the raw credential from the request without validating."""
|
||||
|
||||
@abstractmethod
|
||||
async def authenticate(self, credential: str) -> Any:
|
||||
"""Validate a credential and return the authenticated identity."""
|
||||
|
||||
async def __call__(self, **kwargs: Any) -> Any:
|
||||
"""FastAPI dependency dispatch."""
|
||||
return await self._call_fn(**kwargs)
|
||||
140
src/fastapi_toolsets/security/oauth.py
Normal file
140
src/fastapi_toolsets/security/oauth.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""OAuth 2.0 / OIDC helper utilities."""
|
||||
|
||||
import base64
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
_discovery_cache: dict[str, dict] = {}
|
||||
|
||||
|
||||
async def oauth_resolve_provider_urls(
|
||||
discovery_url: str,
|
||||
) -> tuple[str, str, str | None]:
|
||||
"""Fetch the OIDC discovery document and return endpoint URLs.
|
||||
|
||||
Args:
|
||||
discovery_url: URL of the provider's ``/.well-known/openid-configuration``.
|
||||
|
||||
Returns:
|
||||
A ``(authorization_url, token_url, userinfo_url)`` tuple.
|
||||
*userinfo_url* is ``None`` when the provider does not advertise one.
|
||||
"""
|
||||
if discovery_url not in _discovery_cache:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(discovery_url)
|
||||
resp.raise_for_status()
|
||||
_discovery_cache[discovery_url] = resp.json()
|
||||
cfg = _discovery_cache[discovery_url]
|
||||
return (
|
||||
cfg["authorization_endpoint"],
|
||||
cfg["token_endpoint"],
|
||||
cfg.get("userinfo_endpoint"),
|
||||
)
|
||||
|
||||
|
||||
async def oauth_fetch_userinfo(
|
||||
*,
|
||||
token_url: str,
|
||||
userinfo_url: str,
|
||||
code: str,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
redirect_uri: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Exchange an authorization code for tokens and return the userinfo payload.
|
||||
|
||||
Performs the two-step OAuth 2.0 / OIDC token exchange:
|
||||
|
||||
1. POSTs the authorization *code* to *token_url* to obtain an access token.
|
||||
2. GETs *userinfo_url* using that access token as a Bearer credential.
|
||||
|
||||
Args:
|
||||
token_url: Provider's token endpoint.
|
||||
userinfo_url: Provider's userinfo endpoint.
|
||||
code: Authorization code received from the provider's callback.
|
||||
client_id: OAuth application client ID.
|
||||
client_secret: OAuth application client secret.
|
||||
redirect_uri: Redirect URI that was used in the authorization request.
|
||||
|
||||
Returns:
|
||||
The JSON payload returned by the userinfo endpoint as a plain ``dict``.
|
||||
"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
token_resp = await client.post(
|
||||
token_url,
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
token_resp.raise_for_status()
|
||||
access_token = token_resp.json()["access_token"]
|
||||
|
||||
userinfo_resp = await client.get(
|
||||
userinfo_url,
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
userinfo_resp.raise_for_status()
|
||||
return userinfo_resp.json()
|
||||
|
||||
|
||||
def oauth_build_authorization_redirect(
|
||||
authorization_url: str,
|
||||
*,
|
||||
client_id: str,
|
||||
scopes: str,
|
||||
redirect_uri: str,
|
||||
destination: str,
|
||||
) -> RedirectResponse:
|
||||
"""Return an OAuth 2.0 authorization ``RedirectResponse``.
|
||||
|
||||
Args:
|
||||
authorization_url: Provider's authorization endpoint.
|
||||
client_id: OAuth application client ID.
|
||||
scopes: Space-separated list of requested scopes.
|
||||
redirect_uri: URI the provider should redirect back to after authorization.
|
||||
destination: URL the user should be sent to after the full OAuth flow
|
||||
completes (encoded as ``state``).
|
||||
|
||||
Returns:
|
||||
A :class:`~fastapi.responses.RedirectResponse` to the provider's
|
||||
authorization page.
|
||||
"""
|
||||
params = urlencode(
|
||||
{
|
||||
"client_id": client_id,
|
||||
"response_type": "code",
|
||||
"scope": scopes,
|
||||
"redirect_uri": redirect_uri,
|
||||
"state": oauth_encode_state(destination),
|
||||
}
|
||||
)
|
||||
return RedirectResponse(f"{authorization_url}?{params}")
|
||||
|
||||
|
||||
def oauth_encode_state(url: str) -> str:
|
||||
"""Base64url-encode a URL to embed as an OAuth ``state`` parameter."""
|
||||
return base64.urlsafe_b64encode(url.encode()).decode()
|
||||
|
||||
|
||||
def oauth_decode_state(state: str | None, *, fallback: str) -> str:
|
||||
"""Decode a base64url OAuth ``state`` parameter.
|
||||
|
||||
Handles missing padding (some providers strip ``=``).
|
||||
Returns *fallback* if *state* is absent, the literal string ``"null"``,
|
||||
or cannot be decoded.
|
||||
"""
|
||||
if not state or state == "null":
|
||||
return fallback
|
||||
try:
|
||||
padded = state + "=" * (4 - len(state) % 4)
|
||||
return base64.urlsafe_b64decode(padded).decode()
|
||||
except Exception:
|
||||
return fallback
|
||||
8
src/fastapi_toolsets/security/sources/__init__.py
Normal file
8
src/fastapi_toolsets/security/sources/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Built-in authentication source implementations."""
|
||||
|
||||
from .header import APIKeyHeaderAuth
|
||||
from .bearer import BearerTokenAuth
|
||||
from .cookie import CookieAuth
|
||||
from .multi import MultiAuth
|
||||
|
||||
__all__ = ["APIKeyHeaderAuth", "BearerTokenAuth", "CookieAuth", "MultiAuth"]
|
||||
122
src/fastapi_toolsets/security/sources/bearer.py
Normal file
122
src/fastapi_toolsets/security/sources/bearer.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Bearer token authentication source."""
|
||||
|
||||
import inspect
|
||||
import secrets
|
||||
from typing import Annotated, Any, Callable
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer, SecurityScopes
|
||||
|
||||
from fastapi_toolsets.exceptions import UnauthorizedError
|
||||
|
||||
from ..abc import AuthSource, _call_validator
|
||||
|
||||
|
||||
class BearerTokenAuth(AuthSource):
|
||||
"""Bearer token authentication source.
|
||||
|
||||
Wraps :class:`fastapi.security.HTTPBearer` for OpenAPI documentation.
|
||||
The validator is called as ``await validator(credential, **kwargs)``
|
||||
where ``kwargs`` are the extra keyword arguments provided at instantiation.
|
||||
|
||||
Args:
|
||||
validator: Sync or async callable that receives the credential and any
|
||||
extra keyword arguments, and returns the authenticated identity
|
||||
(e.g. a ``User`` model). Should raise
|
||||
:class:`~fastapi_toolsets.exceptions.UnauthorizedError` on failure.
|
||||
prefix: Optional token prefix (e.g. ``"user_"``). If set, only tokens
|
||||
whose value starts with this prefix are matched. The prefix is
|
||||
**kept** in the value passed to the validator — store and compare
|
||||
tokens with their prefix included. Use :meth:`generate_token` to
|
||||
create correctly-prefixed tokens. This enables multiple
|
||||
``BearerTokenAuth`` instances in the same app (e.g. ``"user_"``
|
||||
for user tokens, ``"org_"`` for org tokens).
|
||||
**kwargs: Extra keyword arguments forwarded to the validator on every
|
||||
call (e.g. ``role=Role.ADMIN``).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
validator: Callable[..., Any],
|
||||
*,
|
||||
prefix: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._validator = validator
|
||||
self._prefix = prefix
|
||||
self._kwargs = kwargs
|
||||
self._scheme = HTTPBearer(auto_error=False)
|
||||
|
||||
_scheme = self._scheme
|
||||
_validator = validator
|
||||
_kwargs = kwargs
|
||||
_prefix = prefix
|
||||
|
||||
async def _call(
|
||||
security_scopes: SecurityScopes, # noqa: ARG001
|
||||
credentials: Annotated[
|
||||
HTTPAuthorizationCredentials | None, Depends(_scheme)
|
||||
] = None,
|
||||
) -> Any:
|
||||
if credentials is None:
|
||||
raise UnauthorizedError()
|
||||
token = credentials.credentials
|
||||
if _prefix is not None and not token.startswith(_prefix):
|
||||
raise UnauthorizedError()
|
||||
return await _call_validator(_validator, token, **_kwargs)
|
||||
|
||||
self._call_fn = _call
|
||||
self.__signature__ = inspect.signature(_call)
|
||||
|
||||
async def extract(self, request: Any) -> str | None:
|
||||
"""Extract the raw credential from the request without validating.
|
||||
|
||||
Returns ``None`` if no ``Authorization: Bearer`` header is present,
|
||||
the token is empty, or the token does not match the configured prefix.
|
||||
The prefix is included in the returned value.
|
||||
"""
|
||||
auth = request.headers.get("Authorization", "")
|
||||
if not auth.startswith("Bearer "):
|
||||
return None
|
||||
token = auth[7:]
|
||||
if not token:
|
||||
return None
|
||||
if self._prefix is not None and not token.startswith(self._prefix):
|
||||
return None
|
||||
return token
|
||||
|
||||
async def authenticate(self, credential: str) -> Any:
|
||||
"""Validate a credential and return the identity.
|
||||
|
||||
Calls ``await validator(credential, **kwargs)`` where ``kwargs`` are
|
||||
the extra keyword arguments provided at instantiation.
|
||||
"""
|
||||
return await _call_validator(self._validator, credential, **self._kwargs)
|
||||
|
||||
def require(self, **kwargs: Any) -> "BearerTokenAuth":
|
||||
"""Return a new instance with additional (or overriding) validator kwargs."""
|
||||
return BearerTokenAuth(
|
||||
self._validator,
|
||||
prefix=self._prefix,
|
||||
**{**self._kwargs, **kwargs},
|
||||
)
|
||||
|
||||
def generate_token(self, nbytes: int = 32) -> str:
|
||||
"""Generate a secure random token for this auth source.
|
||||
|
||||
Returns a URL-safe random token. If a prefix is configured it is
|
||||
prepended — the returned value is what you store in your database
|
||||
and return to the client as-is.
|
||||
|
||||
Args:
|
||||
nbytes: Number of random bytes before base64 encoding. The
|
||||
resulting string is ``ceil(nbytes * 4 / 3)`` characters
|
||||
(43 chars for the default 32 bytes). Defaults to 32.
|
||||
|
||||
Returns:
|
||||
A ready-to-use token string (e.g. ``"user_Xk3..."``).
|
||||
"""
|
||||
token = secrets.token_urlsafe(nbytes)
|
||||
if self._prefix is not None:
|
||||
return f"{self._prefix}{token}"
|
||||
return token
|
||||
142
src/fastapi_toolsets/security/sources/cookie.py
Normal file
142
src/fastapi_toolsets/security/sources/cookie.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""Cookie-based authentication source."""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import inspect
|
||||
import json
|
||||
import time
|
||||
from typing import Annotated, Any, Callable
|
||||
|
||||
from fastapi import Depends, Request, Response
|
||||
from fastapi.security import APIKeyCookie, SecurityScopes
|
||||
|
||||
from fastapi_toolsets.exceptions import UnauthorizedError
|
||||
|
||||
from ..abc import AuthSource, _call_validator
|
||||
|
||||
|
||||
class CookieAuth(AuthSource):
|
||||
"""Cookie-based authentication source.
|
||||
|
||||
Wraps :class:`fastapi.security.APIKeyCookie` for OpenAPI documentation.
|
||||
Optionally signs the cookie with HMAC-SHA256 to provide stateless, tamper-
|
||||
proof sessions without any database entry.
|
||||
|
||||
Args:
|
||||
name: Cookie name.
|
||||
validator: Sync or async callable that receives the cookie value
|
||||
(plain, after signature verification when ``secret_key`` is set)
|
||||
and any extra keyword arguments, and returns the authenticated
|
||||
identity.
|
||||
secret_key: When provided, the cookie is HMAC-SHA256 signed.
|
||||
:meth:`set_cookie` embeds an expiry and signs the payload;
|
||||
:meth:`extract` verifies the signature and expiry before handing
|
||||
the plain value to the validator. When ``None`` (default), the raw
|
||||
cookie value is passed to the validator as-is.
|
||||
ttl: Cookie lifetime in seconds (default 24 h). Only used when
|
||||
``secret_key`` is set.
|
||||
**kwargs: Extra keyword arguments forwarded to the validator on every
|
||||
call (e.g. ``role=Role.ADMIN``).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
validator: Callable[..., Any],
|
||||
*,
|
||||
secret_key: str | None = None,
|
||||
ttl: int = 86400,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._name = name
|
||||
self._validator = validator
|
||||
self._secret_key = secret_key
|
||||
self._ttl = ttl
|
||||
self._kwargs = kwargs
|
||||
self._scheme = APIKeyCookie(name=name, auto_error=False)
|
||||
|
||||
_scheme = self._scheme
|
||||
_self = self
|
||||
_kwargs = kwargs
|
||||
|
||||
async def _call(
|
||||
security_scopes: SecurityScopes, # noqa: ARG001
|
||||
value: Annotated[str | None, Depends(_scheme)] = None,
|
||||
) -> Any:
|
||||
if value is None:
|
||||
raise UnauthorizedError()
|
||||
plain = _self._verify(value)
|
||||
return await _call_validator(_self._validator, plain, **_kwargs)
|
||||
|
||||
self._call_fn = _call
|
||||
self.__signature__ = inspect.signature(_call)
|
||||
|
||||
def _hmac(self, data: str) -> str:
|
||||
assert self._secret_key is not None
|
||||
return hmac.new(
|
||||
self._secret_key.encode(), data.encode(), hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
def _sign(self, value: str) -> str:
|
||||
data = base64.urlsafe_b64encode(
|
||||
json.dumps({"v": value, "exp": int(time.time()) + self._ttl}).encode()
|
||||
).decode()
|
||||
return f"{data}.{self._hmac(data)}"
|
||||
|
||||
def _verify(self, cookie_value: str) -> str:
|
||||
"""Return the plain value, verifying HMAC + expiry when signed."""
|
||||
if not self._secret_key:
|
||||
return cookie_value
|
||||
|
||||
try:
|
||||
data, sig = cookie_value.rsplit(".", 1)
|
||||
except ValueError:
|
||||
raise UnauthorizedError()
|
||||
|
||||
if not hmac.compare_digest(self._hmac(data), sig):
|
||||
raise UnauthorizedError()
|
||||
|
||||
try:
|
||||
payload = json.loads(base64.urlsafe_b64decode(data))
|
||||
value: str = payload["v"]
|
||||
exp: int = payload["exp"]
|
||||
except Exception:
|
||||
raise UnauthorizedError()
|
||||
|
||||
if exp < int(time.time()):
|
||||
raise UnauthorizedError()
|
||||
|
||||
return value
|
||||
|
||||
async def extract(self, request: Request) -> str | None:
|
||||
return request.cookies.get(self._name)
|
||||
|
||||
async def authenticate(self, credential: str) -> Any:
|
||||
plain = self._verify(credential)
|
||||
return await _call_validator(self._validator, plain, **self._kwargs)
|
||||
|
||||
def require(self, **kwargs: Any) -> "CookieAuth":
|
||||
"""Return a new instance with additional (or overriding) validator kwargs."""
|
||||
return CookieAuth(
|
||||
self._name,
|
||||
self._validator,
|
||||
secret_key=self._secret_key,
|
||||
ttl=self._ttl,
|
||||
**{**self._kwargs, **kwargs},
|
||||
)
|
||||
|
||||
def set_cookie(self, response: Response, value: str) -> None:
|
||||
"""Attach the cookie to *response*, signing it when ``secret_key`` is set."""
|
||||
cookie_value = self._sign(value) if self._secret_key else value
|
||||
response.set_cookie(
|
||||
self._name,
|
||||
cookie_value,
|
||||
httponly=True,
|
||||
samesite="lax",
|
||||
max_age=self._ttl,
|
||||
)
|
||||
|
||||
def delete_cookie(self, response: Response) -> None:
|
||||
"""Clear the session cookie (logout)."""
|
||||
response.delete_cookie(self._name, httponly=True, samesite="lax")
|
||||
71
src/fastapi_toolsets/security/sources/header.py
Normal file
71
src/fastapi_toolsets/security/sources/header.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""API key header authentication source."""
|
||||
|
||||
import inspect
|
||||
from typing import Annotated, Any, Callable
|
||||
|
||||
from fastapi import Depends, Request
|
||||
from fastapi.security import APIKeyHeader, SecurityScopes
|
||||
|
||||
from fastapi_toolsets.exceptions import UnauthorizedError
|
||||
|
||||
from ..abc import AuthSource, _call_validator
|
||||
|
||||
|
||||
class APIKeyHeaderAuth(AuthSource):
|
||||
"""API key header authentication source.
|
||||
|
||||
Wraps :class:`fastapi.security.APIKeyHeader` for OpenAPI documentation.
|
||||
The validator is called as ``await validator(api_key, **kwargs)``
|
||||
where ``kwargs`` are the extra keyword arguments provided at instantiation.
|
||||
|
||||
Args:
|
||||
name: HTTP header name that carries the API key (e.g. ``"X-API-Key"``).
|
||||
validator: Sync or async callable that receives the API key and any
|
||||
extra keyword arguments, and returns the authenticated identity.
|
||||
Should raise :class:`~fastapi_toolsets.exceptions.UnauthorizedError`
|
||||
on failure.
|
||||
**kwargs: Extra keyword arguments forwarded to the validator on every
|
||||
call (e.g. ``role=Role.ADMIN``).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
validator: Callable[..., Any],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._name = name
|
||||
self._validator = validator
|
||||
self._kwargs = kwargs
|
||||
self._scheme = APIKeyHeader(name=name, auto_error=False)
|
||||
|
||||
_scheme = self._scheme
|
||||
_validator = validator
|
||||
_kwargs = kwargs
|
||||
|
||||
async def _call(
|
||||
security_scopes: SecurityScopes, # noqa: ARG001
|
||||
api_key: Annotated[str | None, Depends(_scheme)] = None,
|
||||
) -> Any:
|
||||
if api_key is None:
|
||||
raise UnauthorizedError()
|
||||
return await _call_validator(_validator, api_key, **_kwargs)
|
||||
|
||||
self._call_fn = _call
|
||||
self.__signature__ = inspect.signature(_call)
|
||||
|
||||
async def extract(self, request: Request) -> str | None:
|
||||
"""Extract the API key from the configured header."""
|
||||
return request.headers.get(self._name) or None
|
||||
|
||||
async def authenticate(self, credential: str) -> Any:
|
||||
"""Validate a credential and return the identity."""
|
||||
return await _call_validator(self._validator, credential, **self._kwargs)
|
||||
|
||||
def require(self, **kwargs: Any) -> "APIKeyHeaderAuth":
|
||||
"""Return a new instance with additional (or overriding) validator kwargs."""
|
||||
return APIKeyHeaderAuth(
|
||||
self._name,
|
||||
self._validator,
|
||||
**{**self._kwargs, **kwargs},
|
||||
)
|
||||
121
src/fastapi_toolsets/security/sources/multi.py
Normal file
121
src/fastapi_toolsets/security/sources/multi.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""MultiAuth: combine multiple authentication sources into a single callable."""
|
||||
|
||||
import inspect
|
||||
from typing import Any, cast
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.security import SecurityScopes
|
||||
|
||||
from fastapi_toolsets.exceptions import UnauthorizedError
|
||||
|
||||
from ..abc import AuthSource
|
||||
|
||||
|
||||
class MultiAuth:
|
||||
"""Combine multiple authentication sources into a single callable.
|
||||
|
||||
Sources are tried in order; the first one whose
|
||||
:meth:`~AuthSource.extract` returns a non-``None`` credential wins.
|
||||
Its :meth:`~AuthSource.authenticate` is called and the result returned.
|
||||
|
||||
If a credential is found but the validator raises, the exception propagates
|
||||
immediately — the remaining sources are **not** tried. This prevents
|
||||
silent fallthrough on invalid credentials.
|
||||
|
||||
If no source provides a credential,
|
||||
:class:`~fastapi_toolsets.exceptions.UnauthorizedError` is raised.
|
||||
|
||||
The :meth:`~AuthSource.extract` method of each source performs only
|
||||
string matching (no I/O), so prefix-based dispatch is essentially free.
|
||||
|
||||
Any :class:`~AuthSource` subclass — including user-defined ones — can be
|
||||
passed as a source.
|
||||
|
||||
Args:
|
||||
*sources: Auth source instances to try in order.
|
||||
|
||||
Example::
|
||||
|
||||
user_bearer = BearerTokenAuth(verify_user, prefix="user_")
|
||||
org_bearer = BearerTokenAuth(verify_org, prefix="org_")
|
||||
cookie = CookieAuth("session", verify_session)
|
||||
|
||||
multi = MultiAuth(user_bearer, org_bearer, cookie)
|
||||
|
||||
@app.get("/data")
|
||||
async def data_route(user = Security(multi)):
|
||||
return user
|
||||
|
||||
# Apply a shared requirement to all sources at once
|
||||
@app.get("/admin")
|
||||
async def admin_route(user = Security(multi.require(role=Role.ADMIN))):
|
||||
return user
|
||||
"""
|
||||
|
||||
def __init__(self, *sources: AuthSource) -> None:
|
||||
self._sources = sources
|
||||
|
||||
_sources = sources
|
||||
|
||||
async def _call(
|
||||
request: Request,
|
||||
security_scopes: SecurityScopes, # noqa: ARG001
|
||||
**kwargs: Any, # noqa: ARG001 — absorbs scheme values injected by FastAPI
|
||||
) -> Any:
|
||||
for source in _sources:
|
||||
credential = await source.extract(request)
|
||||
if credential is not None:
|
||||
return await source.authenticate(credential)
|
||||
raise UnauthorizedError()
|
||||
|
||||
self._call_fn = _call
|
||||
|
||||
# Build a merged signature that includes the security-scheme Depends()
|
||||
# parameters from every source so FastAPI registers them in OpenAPI docs.
|
||||
seen: set[str] = {"request", "security_scopes"}
|
||||
merged: list[inspect.Parameter] = [
|
||||
inspect.Parameter(
|
||||
"request",
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=Request,
|
||||
),
|
||||
inspect.Parameter(
|
||||
"security_scopes",
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=SecurityScopes,
|
||||
),
|
||||
]
|
||||
for i, source in enumerate(sources):
|
||||
for name, param in inspect.signature(source).parameters.items():
|
||||
if name in seen:
|
||||
continue
|
||||
merged.append(param.replace(name=f"_s{i}_{name}"))
|
||||
seen.add(name)
|
||||
self.__signature__ = inspect.Signature(merged, return_annotation=Any)
|
||||
|
||||
async def __call__(self, **kwargs: Any) -> Any:
|
||||
return await self._call_fn(**kwargs)
|
||||
|
||||
def require(self, **kwargs: Any) -> "MultiAuth":
|
||||
"""Return a new :class:`MultiAuth` with kwargs forwarded to each source.
|
||||
|
||||
Calls ``.require(**kwargs)`` on every source that supports it. Sources
|
||||
that do not implement ``.require()`` (e.g. custom :class:`~AuthSource`
|
||||
subclasses) are passed through unchanged.
|
||||
|
||||
New kwargs are merged over each source's existing kwargs — new values
|
||||
win on conflict::
|
||||
|
||||
multi = MultiAuth(bearer, cookie)
|
||||
|
||||
@app.get("/admin")
|
||||
async def admin(user = Security(multi.require(role=Role.ADMIN))):
|
||||
return user
|
||||
"""
|
||||
new_sources = tuple(
|
||||
cast(Any, source).require(**kwargs)
|
||||
if hasattr(source, "require")
|
||||
else source
|
||||
for source in self._sources
|
||||
)
|
||||
return MultiAuth(*new_sources)
|
||||
1174
tests/test_security.py
Normal file
1174
tests/test_security.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user