mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 14:46:24 +02:00
Compare commits
3 Commits
5a1493266e
...
v2.2.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
dde5183e68
|
|||
|
|
e4250a9910 | ||
|
|
4800941934 |
@@ -1 +0,0 @@
|
|||||||
# Authentication
|
|
||||||
@@ -22,6 +22,8 @@ UserCrud = CrudFactory(model=User)
|
|||||||
|
|
||||||
## Basic operations
|
## Basic operations
|
||||||
|
|
||||||
|
!!! info "`get_or_none` added in `v2.2`"
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# Create
|
# Create
|
||||||
user = await UserCrud.create(session=session, obj=UserCreateSchema(username="alice"))
|
user = await UserCrud.create(session=session, obj=UserCreateSchema(username="alice"))
|
||||||
@@ -29,6 +31,9 @@ user = await UserCrud.create(session=session, obj=UserCreateSchema(username="ali
|
|||||||
# Get one (raises NotFoundError if not found)
|
# Get one (raises NotFoundError if not found)
|
||||||
user = await UserCrud.get(session=session, filters=[User.id == user_id])
|
user = await UserCrud.get(session=session, filters=[User.id == user_id])
|
||||||
|
|
||||||
|
# Get one or None (never raises)
|
||||||
|
user = await UserCrud.get_or_none(session=session, filters=[User.id == user_id])
|
||||||
|
|
||||||
# Get first or None
|
# Get first or None
|
||||||
user = await UserCrud.first(session=session, filters=[User.email == email])
|
user = await UserCrud.first(session=session, filters=[User.email == email])
|
||||||
|
|
||||||
@@ -46,6 +51,36 @@ count = await UserCrud.count(session=session, filters=[User.is_active == True])
|
|||||||
exists = await UserCrud.exists(session=session, filters=[User.email == email])
|
exists = await UserCrud.exists(session=session, filters=[User.email == email])
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Fetching a single record
|
||||||
|
|
||||||
|
Three methods fetch a single record — choose based on how you want to handle the "not found" case and whether you need strict uniqueness:
|
||||||
|
|
||||||
|
| Method | Not found | Multiple results |
|
||||||
|
|---|---|---|
|
||||||
|
| `get` | raises `NotFoundError` | raises `MultipleResultsFound` |
|
||||||
|
| `get_or_none` | returns `None` | raises `MultipleResultsFound` |
|
||||||
|
| `first` | returns `None` | returns the first match silently |
|
||||||
|
|
||||||
|
Use `get` when the record must exist (e.g. a detail endpoint that should return 404):
|
||||||
|
|
||||||
|
```python
|
||||||
|
user = await UserCrud.get(session=session, filters=[User.id == user_id])
|
||||||
|
```
|
||||||
|
|
||||||
|
Use `get_or_none` when the record may not exist but you still want strict uniqueness enforcement:
|
||||||
|
|
||||||
|
```python
|
||||||
|
user = await UserCrud.get_or_none(session=session, filters=[User.email == email])
|
||||||
|
if user is None:
|
||||||
|
... # handle missing case without catching an exception
|
||||||
|
```
|
||||||
|
|
||||||
|
Use `first` when you only care about any one match and don't need uniqueness:
|
||||||
|
|
||||||
|
```python
|
||||||
|
user = await UserCrud.first(session=session, filters=[User.is_active == True])
|
||||||
|
```
|
||||||
|
|
||||||
## Pagination
|
## Pagination
|
||||||
|
|
||||||
!!! info "Added in `v1.1` (only offset_pagination via `paginate` if `<v1.1`)"
|
!!! info "Added in `v1.1` (only offset_pagination via `paginate` if `<v1.1`)"
|
||||||
|
|||||||
@@ -1,267 +0,0 @@
|
|||||||
# Security
|
|
||||||
|
|
||||||
Composable authentication helpers for FastAPI that use `Security()` for OpenAPI documentation and accept user-provided validator functions with full type flexibility.
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
The `security` module provides four auth source classes and a `MultiAuth` factory. Each class wraps a FastAPI security scheme for OpenAPI and accepts a validator function called as:
|
|
||||||
|
|
||||||
```python
|
|
||||||
await validator(credential, **kwargs)
|
|
||||||
```
|
|
||||||
|
|
||||||
where `kwargs` are the extra keyword arguments provided at instantiation (roles, permissions, enums, etc.). The validator returns the authenticated identity (e.g. a `User` model) which becomes the route dependency value.
|
|
||||||
|
|
||||||
```python
|
|
||||||
from fastapi import Security
|
|
||||||
from fastapi_toolsets.security import BearerTokenAuth
|
|
||||||
|
|
||||||
async def verify_token(token: str, *, role: str) -> User:
|
|
||||||
user = await db.get_by_token(token)
|
|
||||||
if not user or user.role != role:
|
|
||||||
raise UnauthorizedError()
|
|
||||||
return user
|
|
||||||
|
|
||||||
bearer_admin = BearerTokenAuth(verify_token, role="admin")
|
|
||||||
|
|
||||||
@app.get("/admin")
|
|
||||||
async def admin_route(user: User = Security(bearer_admin)):
|
|
||||||
return user
|
|
||||||
```
|
|
||||||
|
|
||||||
## Auth sources
|
|
||||||
|
|
||||||
### [`BearerTokenAuth`](../reference/security.md#fastapi_toolsets.security.BearerTokenAuth)
|
|
||||||
|
|
||||||
Reads the `Authorization: Bearer <token>` header. Wraps `HTTPBearer` for OpenAPI.
|
|
||||||
|
|
||||||
```python
|
|
||||||
from fastapi_toolsets.security import BearerTokenAuth
|
|
||||||
|
|
||||||
bearer = BearerTokenAuth(validator=verify_token)
|
|
||||||
|
|
||||||
@app.get("/me")
|
|
||||||
async def me(user: User = Security(bearer)):
|
|
||||||
return user
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Token prefix
|
|
||||||
|
|
||||||
The optional `prefix` parameter restricts a `BearerTokenAuth` instance to tokens
|
|
||||||
that start with a given string. The prefix is **kept** in the value passed to the
|
|
||||||
validator — store and compare tokens with their prefix included.
|
|
||||||
|
|
||||||
This lets you deploy multiple `BearerTokenAuth` instances in the same application
|
|
||||||
and disambiguate them efficiently in `MultiAuth`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
user_bearer = BearerTokenAuth(verify_user, prefix="user_") # matches "Bearer user_..."
|
|
||||||
org_bearer = BearerTokenAuth(verify_org, prefix="org_") # matches "Bearer org_..."
|
|
||||||
```
|
|
||||||
|
|
||||||
Use [`generate_token()`](#token-generation) to create correctly-prefixed tokens.
|
|
||||||
|
|
||||||
#### Token generation
|
|
||||||
|
|
||||||
`BearerTokenAuth.generate_token()` produces a secure random token ready to store
|
|
||||||
in your database and return to the client. If a prefix is configured it is
|
|
||||||
prepended automatically:
|
|
||||||
|
|
||||||
```python
|
|
||||||
bearer = BearerTokenAuth(verify_token, prefix="user_")
|
|
||||||
|
|
||||||
token = bearer.generate_token() # e.g. "user_Xk3mN..."
|
|
||||||
await db.store_token(user_id, token)
|
|
||||||
return {"access_token": token, "token_type": "bearer"}
|
|
||||||
```
|
|
||||||
|
|
||||||
The client sends `Authorization: Bearer user_Xk3mN...` and the validator receives
|
|
||||||
the full token (prefix included) to compare against the stored value.
|
|
||||||
|
|
||||||
### [`CookieAuth`](../reference/security.md#fastapi_toolsets.security.CookieAuth)
|
|
||||||
|
|
||||||
Reads a named cookie. Wraps `APIKeyCookie` for OpenAPI.
|
|
||||||
|
|
||||||
```python
|
|
||||||
from fastapi_toolsets.security import CookieAuth
|
|
||||||
|
|
||||||
cookie_auth = CookieAuth("session", validator=verify_session)
|
|
||||||
|
|
||||||
@app.get("/me")
|
|
||||||
async def me(user: User = Security(cookie_auth)):
|
|
||||||
return user
|
|
||||||
```
|
|
||||||
|
|
||||||
### [`OAuth2Auth`](../reference/security.md#fastapi_toolsets.security.OAuth2Auth)
|
|
||||||
|
|
||||||
Reads the `Authorization: Bearer <token>` header and registers the token endpoint
|
|
||||||
in OpenAPI via `OAuth2PasswordBearer`.
|
|
||||||
|
|
||||||
```python
|
|
||||||
from fastapi_toolsets.security import OAuth2Auth
|
|
||||||
|
|
||||||
oauth2_auth = OAuth2Auth(token_url="/token", validator=verify_token)
|
|
||||||
|
|
||||||
@app.get("/me")
|
|
||||||
async def me(user: User = Security(oauth2_auth)):
|
|
||||||
return user
|
|
||||||
```
|
|
||||||
|
|
||||||
### [`OpenIDAuth`](../reference/security.md#fastapi_toolsets.security.OpenIDAuth)
|
|
||||||
|
|
||||||
Reads the `Authorization: Bearer <token>` header and registers the OpenID Connect
|
|
||||||
discovery URL in OpenAPI via `OpenIdConnect`. Token validation is fully delegated
|
|
||||||
to your validator — use any OIDC / JWT library (`authlib`, `python-jose`, `PyJWT`).
|
|
||||||
|
|
||||||
```python
|
|
||||||
from fastapi_toolsets.security import OpenIDAuth
|
|
||||||
|
|
||||||
async def verify_google_token(token: str, *, audience: str) -> User:
|
|
||||||
payload = jwt.decode(token, google_public_keys, algorithms=["RS256"],
|
|
||||||
audience=audience)
|
|
||||||
return User(email=payload["email"], name=payload["name"])
|
|
||||||
|
|
||||||
google_auth = OpenIDAuth(
|
|
||||||
"https://accounts.google.com/.well-known/openid-configuration",
|
|
||||||
verify_google_token,
|
|
||||||
audience="my-client-id",
|
|
||||||
)
|
|
||||||
|
|
||||||
@app.get("/me")
|
|
||||||
async def me(user: User = Security(google_auth)):
|
|
||||||
return user
|
|
||||||
```
|
|
||||||
|
|
||||||
The discovery URL is used **only for OpenAPI documentation** — no requests are made
|
|
||||||
to it by this class. You are responsible for fetching and caching the provider's
|
|
||||||
public keys in your validator.
|
|
||||||
|
|
||||||
Multiple providers work naturally with `MultiAuth`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
multi = MultiAuth(google_auth, github_auth)
|
|
||||||
|
|
||||||
@app.get("/data")
|
|
||||||
async def data(user: User = Security(multi)):
|
|
||||||
return user
|
|
||||||
```
|
|
||||||
|
|
||||||
## Typed validator kwargs
|
|
||||||
|
|
||||||
All auth classes forward extra instantiation keyword arguments to the validator.
|
|
||||||
Arguments can be any type — enums, strings, integers, etc. The validator returns
|
|
||||||
the authenticated identity, which FastAPI injects directly into the route handler.
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def verify_token(token: str, *, role: Role, permission: str) -> User:
|
|
||||||
user = await decode_token(token)
|
|
||||||
if user.role != role or permission not in user.permissions:
|
|
||||||
raise UnauthorizedError()
|
|
||||||
return user
|
|
||||||
|
|
||||||
bearer = BearerTokenAuth(verify_token, role=Role.ADMIN, permission="billing:read")
|
|
||||||
```
|
|
||||||
|
|
||||||
Each auth instance is self-contained — create a separate instance per distinct
|
|
||||||
requirement instead of passing requirements through `Security(scopes=[...])`.
|
|
||||||
|
|
||||||
### Using `.require()` inline
|
|
||||||
|
|
||||||
If declaring a new top-level variable per role feels verbose, use `.require()` to
|
|
||||||
create a configured clone directly in the route decorator. The original instance
|
|
||||||
is not mutated:
|
|
||||||
|
|
||||||
```python
|
|
||||||
bearer = BearerTokenAuth(verify_token)
|
|
||||||
|
|
||||||
@app.get("/admin/stats")
|
|
||||||
async def admin_stats(user: User = Security(bearer.require(role=Role.ADMIN))):
|
|
||||||
return {"message": f"Hello admin {user.name}"}
|
|
||||||
|
|
||||||
@app.get("/profile")
|
|
||||||
async def profile(user: User = Security(bearer.require(role=Role.USER))):
|
|
||||||
return {"id": user.id, "name": user.name}
|
|
||||||
```
|
|
||||||
|
|
||||||
`.require()` kwargs are merged over existing ones — new values win on conflict.
|
|
||||||
The `prefix` (for `BearerTokenAuth`) and cookie name (for `CookieAuth`) are
|
|
||||||
always preserved.
|
|
||||||
|
|
||||||
`.require()` instances work transparently inside `MultiAuth`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
multi = MultiAuth(
|
|
||||||
user_bearer.require(role=Role.USER),
|
|
||||||
org_bearer.require(role=Role.ADMIN),
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
## MultiAuth
|
|
||||||
|
|
||||||
[`MultiAuth`](../reference/security.md#fastapi_toolsets.security.MultiAuth) combines
|
|
||||||
multiple auth sources into a single callable. Sources are tried in order; the
|
|
||||||
first one that finds a credential wins.
|
|
||||||
|
|
||||||
```python
|
|
||||||
from fastapi_toolsets.security import MultiAuth
|
|
||||||
|
|
||||||
multi = MultiAuth(user_bearer, org_bearer, cookie_auth)
|
|
||||||
|
|
||||||
@app.get("/data")
|
|
||||||
async def data_route(user = Security(multi)):
|
|
||||||
return user
|
|
||||||
```
|
|
||||||
|
|
||||||
### Using `.require()` on MultiAuth
|
|
||||||
|
|
||||||
`MultiAuth` also supports `.require()`, which propagates the kwargs to every
|
|
||||||
source that implements it. Sources that do not (e.g. custom `AuthSource`
|
|
||||||
subclasses) are passed through unchanged:
|
|
||||||
|
|
||||||
```python
|
|
||||||
multi = MultiAuth(bearer, cookie)
|
|
||||||
|
|
||||||
@app.get("/admin")
|
|
||||||
async def admin(user: User = Security(multi.require(role=Role.ADMIN))):
|
|
||||||
return user
|
|
||||||
```
|
|
||||||
|
|
||||||
This is equivalent to calling `.require()` on each source individually:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# These two are identical
|
|
||||||
multi.require(role=Role.ADMIN)
|
|
||||||
|
|
||||||
MultiAuth(
|
|
||||||
bearer.require(role=Role.ADMIN),
|
|
||||||
cookie.require(role=Role.ADMIN),
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Prefix-based dispatch
|
|
||||||
|
|
||||||
Because `extract()` is pure string matching (no I/O), prefix-based source
|
|
||||||
selection is essentially free. Only the matching source's validator (which may
|
|
||||||
involve DB or network I/O) is ever called:
|
|
||||||
|
|
||||||
```python
|
|
||||||
user_bearer = BearerTokenAuth(verify_user, prefix="user_")
|
|
||||||
org_bearer = BearerTokenAuth(verify_org, prefix="org_")
|
|
||||||
|
|
||||||
multi = MultiAuth(user_bearer, org_bearer)
|
|
||||||
|
|
||||||
# "Bearer user_alice" → only verify_user runs, receives "user_alice"
|
|
||||||
# "Bearer org_acme" → only verify_org runs, receives "org_acme"
|
|
||||||
```
|
|
||||||
|
|
||||||
Tokens are stored and compared **with their prefix** — use `generate_token()` on
|
|
||||||
each source to issue correctly-prefixed tokens:
|
|
||||||
|
|
||||||
```python
|
|
||||||
user_token = user_bearer.generate_token() # "user_..."
|
|
||||||
org_token = org_bearer.generate_token() # "org_..."
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
[:material-api: API Reference](../reference/security.md)
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
# `security`
|
|
||||||
|
|
||||||
Here's the reference for the authentication helpers provided by the `security` module.
|
|
||||||
|
|
||||||
You can import them directly from `fastapi_toolsets.security`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from fastapi_toolsets.security import (
|
|
||||||
AuthSource,
|
|
||||||
BearerTokenAuth,
|
|
||||||
CookieAuth,
|
|
||||||
OAuth2Auth,
|
|
||||||
OpenIDAuth,
|
|
||||||
MultiAuth,
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
## ::: fastapi_toolsets.security.AuthSource
|
|
||||||
|
|
||||||
## ::: fastapi_toolsets.security.BearerTokenAuth
|
|
||||||
|
|
||||||
## ::: fastapi_toolsets.security.CookieAuth
|
|
||||||
|
|
||||||
## ::: fastapi_toolsets.security.OAuth2Auth
|
|
||||||
|
|
||||||
## ::: fastapi_toolsets.security.OpenIDAuth
|
|
||||||
|
|
||||||
## ::: fastapi_toolsets.security.MultiAuth
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
from fastapi import FastAPI
|
|
||||||
|
|
||||||
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
|
||||||
|
|
||||||
from .routes import router
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
init_exceptions_handlers(app=app)
|
|
||||||
app.include_router(router=router)
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
from fastapi_toolsets.crud import CrudFactory
|
|
||||||
|
|
||||||
from .models import OAuthAccount, OAuthProvider, Team, User, UserToken
|
|
||||||
|
|
||||||
TeamCrud = CrudFactory(model=Team)
|
|
||||||
UserCrud = CrudFactory(model=User)
|
|
||||||
UserTokenCrud = CrudFactory(model=UserToken)
|
|
||||||
OAuthProviderCrud = CrudFactory(model=OAuthProvider)
|
|
||||||
OAuthAccountCrud = CrudFactory(model=OAuthAccount)
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
from fastapi import Depends
|
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
|
||||||
|
|
||||||
from fastapi_toolsets.db import create_db_context, create_db_dependency
|
|
||||||
|
|
||||||
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/postgres"
|
|
||||||
|
|
||||||
engine = create_async_engine(url=DATABASE_URL, future=True)
|
|
||||||
async_session_maker = async_sessionmaker(bind=engine, expire_on_commit=False)
|
|
||||||
|
|
||||||
get_db = create_db_dependency(session_maker=async_session_maker)
|
|
||||||
get_db_context = create_db_context(session_maker=async_session_maker)
|
|
||||||
|
|
||||||
|
|
||||||
SessionDep = Depends(get_db)
|
|
||||||
@@ -1,105 +0,0 @@
|
|||||||
import enum
|
|
||||||
from datetime import datetime
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from sqlalchemy import (
|
|
||||||
Boolean,
|
|
||||||
DateTime,
|
|
||||||
Enum,
|
|
||||||
ForeignKey,
|
|
||||||
Integer,
|
|
||||||
String,
|
|
||||||
UniqueConstraint,
|
|
||||||
)
|
|
||||||
from sqlalchemy.dialects.postgresql import UUID as PG_UUID
|
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
|
||||||
|
|
||||||
from fastapi_toolsets.models import TimestampMixin, UUIDMixin
|
|
||||||
|
|
||||||
|
|
||||||
class Base(DeclarativeBase, UUIDMixin):
|
|
||||||
type_annotation_map = {
|
|
||||||
str: String(),
|
|
||||||
int: Integer(),
|
|
||||||
UUID: PG_UUID(as_uuid=True),
|
|
||||||
datetime: DateTime(timezone=True),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class UserRole(enum.Enum):
|
|
||||||
admin = "admin"
|
|
||||||
moderator = "moderator"
|
|
||||||
user = "user"
|
|
||||||
|
|
||||||
|
|
||||||
class Team(Base, TimestampMixin):
|
|
||||||
__tablename__ = "teams"
|
|
||||||
|
|
||||||
name: Mapped[str] = mapped_column(String, unique=True, index=True)
|
|
||||||
users: Mapped[list["User"]] = relationship(back_populates="team")
|
|
||||||
|
|
||||||
|
|
||||||
class User(Base, TimestampMixin):
|
|
||||||
__tablename__ = "users"
|
|
||||||
|
|
||||||
username: Mapped[str] = mapped_column(String, unique=True, index=True)
|
|
||||||
email: Mapped[str | None] = mapped_column(
|
|
||||||
String, unique=True, index=True, nullable=True
|
|
||||||
)
|
|
||||||
hashed_password: Mapped[str | None] = mapped_column(String, nullable=True)
|
|
||||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
|
||||||
role: Mapped[UserRole] = mapped_column(Enum(UserRole), default=UserRole.user)
|
|
||||||
|
|
||||||
team_id: Mapped[UUID | None] = mapped_column(ForeignKey("teams.id"), nullable=True)
|
|
||||||
team: Mapped["Team | None"] = relationship(back_populates="users")
|
|
||||||
oauth_accounts: Mapped[list["OAuthAccount"]] = relationship(back_populates="user")
|
|
||||||
tokens: Mapped[list["UserToken"]] = relationship(back_populates="user")
|
|
||||||
|
|
||||||
|
|
||||||
class UserToken(Base, TimestampMixin):
|
|
||||||
"""API tokens for a user (multiple allowed)."""
|
|
||||||
|
|
||||||
__tablename__ = "user_tokens"
|
|
||||||
|
|
||||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("users.id"))
|
|
||||||
# Store hashed token value
|
|
||||||
token_hash: Mapped[str] = mapped_column(String, unique=True, index=True)
|
|
||||||
name: Mapped[str | None] = mapped_column(String, nullable=True)
|
|
||||||
expires_at: Mapped[datetime | None] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=True
|
|
||||||
)
|
|
||||||
|
|
||||||
user: Mapped["User"] = relationship(back_populates="tokens")
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthProvider(Base, TimestampMixin):
|
|
||||||
"""Configurable OAuth2 / OpenID Connect provider."""
|
|
||||||
|
|
||||||
__tablename__ = "oauth_providers"
|
|
||||||
|
|
||||||
slug: Mapped[str] = mapped_column(String, unique=True, index=True)
|
|
||||||
name: Mapped[str] = mapped_column(String)
|
|
||||||
client_id: Mapped[str] = mapped_column(String)
|
|
||||||
client_secret: Mapped[str] = mapped_column(String)
|
|
||||||
discovery_url: Mapped[str] = mapped_column(String, nullable=False)
|
|
||||||
scopes: Mapped[str] = mapped_column(String, default="openid email profile")
|
|
||||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
|
||||||
|
|
||||||
accounts: Mapped[list["OAuthAccount"]] = relationship(back_populates="provider")
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthAccount(Base, TimestampMixin):
|
|
||||||
"""OAuth2 / OpenID Connect account linked to a user."""
|
|
||||||
|
|
||||||
__tablename__ = "oauth_accounts"
|
|
||||||
__table_args__ = (
|
|
||||||
UniqueConstraint("provider_id", "subject", name="uq_oauth_provider_subject"),
|
|
||||||
)
|
|
||||||
|
|
||||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("users.id"))
|
|
||||||
provider_id: Mapped[UUID] = mapped_column(ForeignKey("oauth_providers.id"))
|
|
||||||
# OAuth `sub` / OpenID subject identifier
|
|
||||||
subject: Mapped[str] = mapped_column(String)
|
|
||||||
|
|
||||||
user: Mapped["User"] = relationship(back_populates="oauth_accounts")
|
|
||||||
provider: Mapped["OAuthProvider"] = relationship(back_populates="accounts")
|
|
||||||
@@ -1,122 +0,0 @@
|
|||||||
from typing import Annotated
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
import bcrypt
|
|
||||||
from fastapi import APIRouter, Form, HTTPException, Response, Security
|
|
||||||
|
|
||||||
from fastapi_toolsets.dependencies import PathDependency
|
|
||||||
|
|
||||||
from .crud import UserCrud, UserTokenCrud
|
|
||||||
from .db import SessionDep
|
|
||||||
from .models import OAuthProvider, User, UserToken
|
|
||||||
from .schemas import (
|
|
||||||
ApiTokenCreateRequest,
|
|
||||||
ApiTokenResponse,
|
|
||||||
RegisterRequest,
|
|
||||||
UserCreate,
|
|
||||||
UserResponse,
|
|
||||||
)
|
|
||||||
from .security import auth, cookie_auth, create_api_token
|
|
||||||
|
|
||||||
ProviderDep = PathDependency(
|
|
||||||
model=OAuthProvider,
|
|
||||||
field=OAuthProvider.slug,
|
|
||||||
session_dep=SessionDep,
|
|
||||||
param_name="slug",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def hash_password(password: str) -> str:
|
|
||||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
|
||||||
|
|
||||||
|
|
||||||
def verify_password(plain: str, hashed: str) -> bool:
|
|
||||||
return bcrypt.checkpw(plain.encode(), hashed.encode())
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/auth")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/register", response_model=UserResponse, status_code=201)
|
|
||||||
async def register(body: RegisterRequest, session: SessionDep):
|
|
||||||
existing = await UserCrud.first(
|
|
||||||
session=session, filters=[User.username == body.username]
|
|
||||||
)
|
|
||||||
if existing:
|
|
||||||
raise HTTPException(status_code=409, detail="Username already taken")
|
|
||||||
|
|
||||||
user = await UserCrud.create(
|
|
||||||
session=session,
|
|
||||||
obj=UserCreate(
|
|
||||||
username=body.username,
|
|
||||||
email=body.email,
|
|
||||||
hashed_password=hash_password(body.password),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/token", status_code=204)
|
|
||||||
async def login(
|
|
||||||
session: SessionDep,
|
|
||||||
response: Response,
|
|
||||||
username: Annotated[str, Form()],
|
|
||||||
password: Annotated[str, Form()],
|
|
||||||
):
|
|
||||||
user = await UserCrud.first(session=session, filters=[User.username == username])
|
|
||||||
|
|
||||||
if (
|
|
||||||
not user
|
|
||||||
or not user.hashed_password
|
|
||||||
or not verify_password(password, user.hashed_password)
|
|
||||||
):
|
|
||||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
|
||||||
|
|
||||||
if not user.is_active:
|
|
||||||
raise HTTPException(status_code=403, detail="Account disabled")
|
|
||||||
|
|
||||||
cookie_auth.set_cookie(response, str(user.id))
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/logout", status_code=204)
|
|
||||||
async def logout(response: Response):
|
|
||||||
cookie_auth.delete_cookie(response)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me", response_model=UserResponse)
|
|
||||||
async def me(user: User = Security(auth)):
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/tokens", response_model=ApiTokenResponse, status_code=201)
|
|
||||||
async def create_token(
|
|
||||||
body: ApiTokenCreateRequest,
|
|
||||||
user: User = Security(auth),
|
|
||||||
):
|
|
||||||
raw, token_row = await create_api_token(
|
|
||||||
user.id, name=body.name, expires_at=body.expires_at
|
|
||||||
)
|
|
||||||
return ApiTokenResponse(
|
|
||||||
id=token_row.id,
|
|
||||||
name=token_row.name,
|
|
||||||
expires_at=token_row.expires_at,
|
|
||||||
created_at=token_row.created_at,
|
|
||||||
token=raw,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/tokens/{token_id}", status_code=204)
|
|
||||||
async def revoke_token(
|
|
||||||
session: SessionDep,
|
|
||||||
token_id: UUID,
|
|
||||||
user: User = Security(auth),
|
|
||||||
):
|
|
||||||
if not await UserTokenCrud.first(
|
|
||||||
session=session,
|
|
||||||
filters=[UserToken.id == token_id, UserToken.user_id == user.id],
|
|
||||||
):
|
|
||||||
raise HTTPException(status_code=404, detail="Token not found")
|
|
||||||
await UserTokenCrud.delete(
|
|
||||||
session=session,
|
|
||||||
filters=[UserToken.id == token_id, UserToken.user_id == user.id],
|
|
||||||
)
|
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
from datetime import datetime
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from pydantic import EmailStr
|
|
||||||
|
|
||||||
from fastapi_toolsets.schemas import PydanticBase
|
|
||||||
|
|
||||||
|
|
||||||
class RegisterRequest(PydanticBase):
|
|
||||||
username: str
|
|
||||||
password: str
|
|
||||||
email: EmailStr | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class UserResponse(PydanticBase):
|
|
||||||
id: UUID
|
|
||||||
username: str
|
|
||||||
email: str | None
|
|
||||||
role: str
|
|
||||||
is_active: bool
|
|
||||||
|
|
||||||
model_config = {"from_attributes": True}
|
|
||||||
|
|
||||||
|
|
||||||
class ApiTokenCreateRequest(PydanticBase):
|
|
||||||
name: str | None = None
|
|
||||||
expires_at: datetime | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ApiTokenResponse(PydanticBase):
|
|
||||||
id: UUID
|
|
||||||
name: str | None
|
|
||||||
expires_at: datetime | None
|
|
||||||
created_at: datetime
|
|
||||||
# Only populated on creation
|
|
||||||
token: str | None = None
|
|
||||||
|
|
||||||
model_config = {"from_attributes": True}
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthProviderResponse(PydanticBase):
|
|
||||||
slug: str
|
|
||||||
name: str
|
|
||||||
|
|
||||||
model_config = {"from_attributes": True}
|
|
||||||
|
|
||||||
|
|
||||||
class UserCreate(PydanticBase):
|
|
||||||
username: str
|
|
||||||
email: str | None = None
|
|
||||||
hashed_password: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class UserTokenCreate(PydanticBase):
|
|
||||||
user_id: UUID
|
|
||||||
token_hash: str
|
|
||||||
name: str | None = None
|
|
||||||
expires_at: datetime | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthAccountCreate(PydanticBase):
|
|
||||||
user_id: UUID
|
|
||||||
provider_id: UUID
|
|
||||||
subject: str
|
|
||||||
@@ -1,100 +0,0 @@
|
|||||||
import hashlib
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from fastapi import HTTPException
|
|
||||||
from sqlalchemy.orm import selectinload
|
|
||||||
|
|
||||||
from fastapi_toolsets.exceptions import UnauthorizedError
|
|
||||||
from fastapi_toolsets.security import (
|
|
||||||
APIKeyHeaderAuth,
|
|
||||||
BearerTokenAuth,
|
|
||||||
CookieAuth,
|
|
||||||
MultiAuth,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .crud import UserCrud, UserTokenCrud
|
|
||||||
from .db import get_db_context
|
|
||||||
from .models import User, UserRole, UserToken
|
|
||||||
from .schemas import UserTokenCreate
|
|
||||||
|
|
||||||
SESSION_COOKIE = "session"
|
|
||||||
SECRET_KEY = "123456789"
|
|
||||||
|
|
||||||
|
|
||||||
def _hash_token(token: str) -> str:
|
|
||||||
return hashlib.sha256(token.encode()).hexdigest()
|
|
||||||
|
|
||||||
|
|
||||||
async def _verify_token(token: str, role: UserRole | None = None) -> User:
|
|
||||||
async with get_db_context() as db:
|
|
||||||
user_token = await UserTokenCrud.first(
|
|
||||||
session=db,
|
|
||||||
filters=[UserToken.token_hash == _hash_token(token)],
|
|
||||||
load_options=[selectinload(UserToken.user)],
|
|
||||||
)
|
|
||||||
|
|
||||||
if user_token is None or not user_token.user.is_active:
|
|
||||||
raise UnauthorizedError()
|
|
||||||
|
|
||||||
if user_token.expires_at and user_token.expires_at < datetime.now(timezone.utc):
|
|
||||||
raise UnauthorizedError()
|
|
||||||
|
|
||||||
user = user_token.user
|
|
||||||
|
|
||||||
if role is not None and user.role != role:
|
|
||||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
|
||||||
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
async def _verify_cookie(user_id: str, role: UserRole | None = None) -> User:
|
|
||||||
async with get_db_context() as db:
|
|
||||||
user = await UserCrud.first(
|
|
||||||
session=db,
|
|
||||||
filters=[User.id == UUID(user_id)],
|
|
||||||
)
|
|
||||||
|
|
||||||
if not user or not user.is_active:
|
|
||||||
raise UnauthorizedError()
|
|
||||||
|
|
||||||
if role is not None and user.role != role:
|
|
||||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
|
||||||
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
bearer_auth = BearerTokenAuth(
|
|
||||||
validator=_verify_token,
|
|
||||||
prefix="ctf_",
|
|
||||||
)
|
|
||||||
header_auth = APIKeyHeaderAuth(
|
|
||||||
name="X-API-Key",
|
|
||||||
validator=_verify_token,
|
|
||||||
)
|
|
||||||
cookie_auth = CookieAuth(
|
|
||||||
name=SESSION_COOKIE,
|
|
||||||
validator=_verify_cookie,
|
|
||||||
secret_key=SECRET_KEY,
|
|
||||||
)
|
|
||||||
auth = MultiAuth(bearer_auth, header_auth, cookie_auth)
|
|
||||||
|
|
||||||
|
|
||||||
async def create_api_token(
|
|
||||||
user_id: UUID,
|
|
||||||
*,
|
|
||||||
name: str | None = None,
|
|
||||||
expires_at: datetime | None = None,
|
|
||||||
) -> tuple[str, UserToken]:
|
|
||||||
raw = bearer_auth.generate_token()
|
|
||||||
async with get_db_context() as db:
|
|
||||||
token_row = await UserTokenCrud.create(
|
|
||||||
session=db,
|
|
||||||
obj=UserTokenCreate(
|
|
||||||
user_id=user_id,
|
|
||||||
token_hash=_hash_token(raw),
|
|
||||||
name=name,
|
|
||||||
expires_at=expires_at,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return raw, token_row
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "fastapi-toolsets"
|
name = "fastapi-toolsets"
|
||||||
version = "2.1.0"
|
version = "2.2.0"
|
||||||
description = "Production-ready utilities for FastAPI applications"
|
description = "Production-ready utilities for FastAPI applications"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
|
|||||||
@@ -21,4 +21,4 @@ Example usage:
|
|||||||
return Response(data={"user": user.username}, message="Success")
|
return Response(data={"user": user.username}, message="Success")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "2.1.0"
|
__version__ = "2.2.0"
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
|
|||||||
from fastapi import Query
|
from fastapi import Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import Date, DateTime, Float, Integer, Numeric, Uuid, and_, func, select
|
from sqlalchemy import Date, DateTime, Float, Integer, Numeric, Uuid, and_, func, select
|
||||||
from sqlalchemy import delete as sql_delete
|
|
||||||
from sqlalchemy.dialects.postgresql import insert
|
from sqlalchemy.dialects.postgresql import insert
|
||||||
from sqlalchemy.exc import NoResultFound
|
from sqlalchemy.exc import NoResultFound
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
@@ -410,6 +409,82 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
NotFoundError: If no record found
|
NotFoundError: If no record found
|
||||||
MultipleResultsFound: If more than one record found
|
MultipleResultsFound: If more than one record found
|
||||||
"""
|
"""
|
||||||
|
result = await cls.get_or_none(
|
||||||
|
session,
|
||||||
|
filters,
|
||||||
|
joins=joins,
|
||||||
|
outer_join=outer_join,
|
||||||
|
with_for_update=with_for_update,
|
||||||
|
load_options=load_options,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
if result is None:
|
||||||
|
raise NotFoundError()
|
||||||
|
return result
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
async def get_or_none( # pragma: no cover
|
||||||
|
cls: type[Self],
|
||||||
|
session: AsyncSession,
|
||||||
|
filters: list[Any],
|
||||||
|
*,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
|
with_for_update: bool = False,
|
||||||
|
load_options: list[ExecutableOption] | None = None,
|
||||||
|
schema: type[SchemaType],
|
||||||
|
) -> Response[SchemaType] | None: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
async def get_or_none( # pragma: no cover
|
||||||
|
cls: type[Self],
|
||||||
|
session: AsyncSession,
|
||||||
|
filters: list[Any],
|
||||||
|
*,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
|
with_for_update: bool = False,
|
||||||
|
load_options: list[ExecutableOption] | None = None,
|
||||||
|
schema: None = ...,
|
||||||
|
) -> ModelType | None: ...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_or_none(
|
||||||
|
cls: type[Self],
|
||||||
|
session: AsyncSession,
|
||||||
|
filters: list[Any],
|
||||||
|
*,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
|
with_for_update: bool = False,
|
||||||
|
load_options: list[ExecutableOption] | None = None,
|
||||||
|
schema: type[BaseModel] | None = None,
|
||||||
|
) -> ModelType | Response[Any] | None:
|
||||||
|
"""Get exactly one record, or ``None`` if not found.
|
||||||
|
|
||||||
|
Like :meth:`get` but returns ``None`` instead of raising
|
||||||
|
:class:`~fastapi_toolsets.exceptions.NotFoundError` when no record
|
||||||
|
matches the filters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: DB async session
|
||||||
|
filters: List of SQLAlchemy filter conditions
|
||||||
|
joins: List of (model, condition) tuples for joining related tables
|
||||||
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||||
|
with_for_update: Lock the row for update
|
||||||
|
load_options: SQLAlchemy loader options (e.g., selectinload)
|
||||||
|
schema: Pydantic schema to serialize the result into. When provided,
|
||||||
|
the result is automatically wrapped in a ``Response[schema]``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Model instance, ``Response[schema]`` when ``schema`` is given,
|
||||||
|
or ``None`` when no record matches.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MultipleResultsFound: If more than one record found
|
||||||
|
"""
|
||||||
q = select(cls.model)
|
q = select(cls.model)
|
||||||
q = _apply_joins(q, joins, outer_join)
|
q = _apply_joins(q, joins, outer_join)
|
||||||
q = q.where(and_(*filters))
|
q = q.where(and_(*filters))
|
||||||
@@ -419,12 +494,40 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
q = q.with_for_update()
|
q = q.with_for_update()
|
||||||
result = await session.execute(q)
|
result = await session.execute(q)
|
||||||
item = result.unique().scalar_one_or_none()
|
item = result.unique().scalar_one_or_none()
|
||||||
if not item:
|
if item is None:
|
||||||
raise NotFoundError()
|
return None
|
||||||
result = cast(ModelType, item)
|
db_model = cast(ModelType, item)
|
||||||
if schema:
|
if schema:
|
||||||
return Response(data=schema.model_validate(result))
|
return Response(data=schema.model_validate(db_model))
|
||||||
return result
|
return db_model
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
async def first( # pragma: no cover
|
||||||
|
cls: type[Self],
|
||||||
|
session: AsyncSession,
|
||||||
|
filters: list[Any] | None = None,
|
||||||
|
*,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
|
with_for_update: bool = False,
|
||||||
|
load_options: list[ExecutableOption] | None = None,
|
||||||
|
schema: type[SchemaType],
|
||||||
|
) -> Response[SchemaType] | None: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
async def first( # pragma: no cover
|
||||||
|
cls: type[Self],
|
||||||
|
session: AsyncSession,
|
||||||
|
filters: list[Any] | None = None,
|
||||||
|
*,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
|
with_for_update: bool = False,
|
||||||
|
load_options: list[ExecutableOption] | None = None,
|
||||||
|
schema: None = ...,
|
||||||
|
) -> ModelType | None: ...
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def first(
|
async def first(
|
||||||
@@ -434,8 +537,10 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
*,
|
*,
|
||||||
joins: JoinType | None = None,
|
joins: JoinType | None = None,
|
||||||
outer_join: bool = False,
|
outer_join: bool = False,
|
||||||
|
with_for_update: bool = False,
|
||||||
load_options: list[ExecutableOption] | None = None,
|
load_options: list[ExecutableOption] | None = None,
|
||||||
) -> ModelType | None:
|
schema: type[BaseModel] | None = None,
|
||||||
|
) -> ModelType | Response[Any] | None:
|
||||||
"""Get the first matching record, or None.
|
"""Get the first matching record, or None.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -443,10 +548,14 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
filters: List of SQLAlchemy filter conditions
|
filters: List of SQLAlchemy filter conditions
|
||||||
joins: List of (model, condition) tuples for joining related tables
|
joins: List of (model, condition) tuples for joining related tables
|
||||||
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||||
load_options: SQLAlchemy loader options
|
with_for_update: Lock the row for update
|
||||||
|
load_options: SQLAlchemy loader options (e.g., selectinload)
|
||||||
|
schema: Pydantic schema to serialize the result into. When provided,
|
||||||
|
the result is automatically wrapped in a ``Response[schema]``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Model instance or None
|
Model instance, ``Response[schema]`` when ``schema`` is given,
|
||||||
|
or ``None`` when no record matches.
|
||||||
"""
|
"""
|
||||||
q = select(cls.model)
|
q = select(cls.model)
|
||||||
q = _apply_joins(q, joins, outer_join)
|
q = _apply_joins(q, joins, outer_join)
|
||||||
@@ -454,8 +563,16 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
q = q.where(and_(*filters))
|
q = q.where(and_(*filters))
|
||||||
if resolved := cls._resolve_load_options(load_options):
|
if resolved := cls._resolve_load_options(load_options):
|
||||||
q = q.options(*resolved)
|
q = q.options(*resolved)
|
||||||
|
if with_for_update:
|
||||||
|
q = q.with_for_update()
|
||||||
result = await session.execute(q)
|
result = await session.execute(q)
|
||||||
return cast(ModelType | None, result.unique().scalars().first())
|
item = result.unique().scalars().first()
|
||||||
|
if item is None:
|
||||||
|
return None
|
||||||
|
db_model = cast(ModelType, item)
|
||||||
|
if schema:
|
||||||
|
return Response(data=schema.model_validate(db_model))
|
||||||
|
return db_model
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_multi(
|
async def get_multi(
|
||||||
@@ -674,8 +791,10 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
``None``, or ``Response[None]`` when ``return_response=True``.
|
``None``, or ``Response[None]`` when ``return_response=True``.
|
||||||
"""
|
"""
|
||||||
async with get_transaction(session):
|
async with get_transaction(session):
|
||||||
q = sql_delete(cls.model).where(and_(*filters))
|
result = await session.execute(select(cls.model).where(and_(*filters)))
|
||||||
await session.execute(q)
|
objects = result.scalars().all()
|
||||||
|
for obj in objects:
|
||||||
|
await session.delete(obj)
|
||||||
if return_response:
|
if return_response:
|
||||||
return Response(data=None)
|
return Response(data=None)
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -1,24 +0,0 @@
|
|||||||
"""Authentication helpers for FastAPI using Security()."""
|
|
||||||
|
|
||||||
from .abc import AuthSource
|
|
||||||
from .oauth import (
|
|
||||||
oauth_build_authorization_redirect,
|
|
||||||
oauth_decode_state,
|
|
||||||
oauth_encode_state,
|
|
||||||
oauth_fetch_userinfo,
|
|
||||||
oauth_resolve_provider_urls,
|
|
||||||
)
|
|
||||||
from .sources import APIKeyHeaderAuth, BearerTokenAuth, CookieAuth, MultiAuth
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"APIKeyHeaderAuth",
|
|
||||||
"AuthSource",
|
|
||||||
"BearerTokenAuth",
|
|
||||||
"CookieAuth",
|
|
||||||
"MultiAuth",
|
|
||||||
"oauth_build_authorization_redirect",
|
|
||||||
"oauth_decode_state",
|
|
||||||
"oauth_encode_state",
|
|
||||||
"oauth_fetch_userinfo",
|
|
||||||
"oauth_resolve_provider_urls",
|
|
||||||
]
|
|
||||||
@@ -1,51 +0,0 @@
|
|||||||
"""Abstract base class for authentication sources."""
|
|
||||||
|
|
||||||
import inspect
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, Callable
|
|
||||||
|
|
||||||
from fastapi import Request
|
|
||||||
from fastapi.security import SecurityScopes
|
|
||||||
|
|
||||||
from fastapi_toolsets.exceptions import UnauthorizedError
|
|
||||||
|
|
||||||
|
|
||||||
async def _call_validator(
|
|
||||||
validator: Callable[..., Any], *args: Any, **kwargs: Any
|
|
||||||
) -> Any:
|
|
||||||
"""Call *validator* with *args* and *kwargs*, awaiting it if it is a coroutine function."""
|
|
||||||
if inspect.iscoroutinefunction(validator):
|
|
||||||
return await validator(*args, **kwargs)
|
|
||||||
return validator(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class AuthSource(ABC):
|
|
||||||
"""Abstract base class for authentication sources."""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
"""Set up the default FastAPI dependency signature."""
|
|
||||||
source = self
|
|
||||||
|
|
||||||
async def _call(
|
|
||||||
request: Request,
|
|
||||||
security_scopes: SecurityScopes, # noqa: ARG001
|
|
||||||
) -> Any:
|
|
||||||
credential = await source.extract(request)
|
|
||||||
if credential is None:
|
|
||||||
raise UnauthorizedError()
|
|
||||||
return await source.authenticate(credential)
|
|
||||||
|
|
||||||
self._call_fn: Callable[..., Any] = _call
|
|
||||||
self.__signature__ = inspect.signature(_call)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def extract(self, request: Request) -> str | None:
|
|
||||||
"""Extract the raw credential from the request without validating."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def authenticate(self, credential: str) -> Any:
|
|
||||||
"""Validate a credential and return the authenticated identity."""
|
|
||||||
|
|
||||||
async def __call__(self, **kwargs: Any) -> Any:
|
|
||||||
"""FastAPI dependency dispatch."""
|
|
||||||
return await self._call_fn(**kwargs)
|
|
||||||
@@ -1,140 +0,0 @@
|
|||||||
"""OAuth 2.0 / OIDC helper utilities."""
|
|
||||||
|
|
||||||
import base64
|
|
||||||
from typing import Any
|
|
||||||
from urllib.parse import urlencode
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
from fastapi.responses import RedirectResponse
|
|
||||||
|
|
||||||
_discovery_cache: dict[str, dict] = {}
|
|
||||||
|
|
||||||
|
|
||||||
async def oauth_resolve_provider_urls(
|
|
||||||
discovery_url: str,
|
|
||||||
) -> tuple[str, str, str | None]:
|
|
||||||
"""Fetch the OIDC discovery document and return endpoint URLs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
discovery_url: URL of the provider's ``/.well-known/openid-configuration``.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A ``(authorization_url, token_url, userinfo_url)`` tuple.
|
|
||||||
*userinfo_url* is ``None`` when the provider does not advertise one.
|
|
||||||
"""
|
|
||||||
if discovery_url not in _discovery_cache:
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
resp = await client.get(discovery_url)
|
|
||||||
resp.raise_for_status()
|
|
||||||
_discovery_cache[discovery_url] = resp.json()
|
|
||||||
cfg = _discovery_cache[discovery_url]
|
|
||||||
return (
|
|
||||||
cfg["authorization_endpoint"],
|
|
||||||
cfg["token_endpoint"],
|
|
||||||
cfg.get("userinfo_endpoint"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def oauth_fetch_userinfo(
|
|
||||||
*,
|
|
||||||
token_url: str,
|
|
||||||
userinfo_url: str,
|
|
||||||
code: str,
|
|
||||||
client_id: str,
|
|
||||||
client_secret: str,
|
|
||||||
redirect_uri: str,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Exchange an authorization code for tokens and return the userinfo payload.
|
|
||||||
|
|
||||||
Performs the two-step OAuth 2.0 / OIDC token exchange:
|
|
||||||
|
|
||||||
1. POSTs the authorization *code* to *token_url* to obtain an access token.
|
|
||||||
2. GETs *userinfo_url* using that access token as a Bearer credential.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token_url: Provider's token endpoint.
|
|
||||||
userinfo_url: Provider's userinfo endpoint.
|
|
||||||
code: Authorization code received from the provider's callback.
|
|
||||||
client_id: OAuth application client ID.
|
|
||||||
client_secret: OAuth application client secret.
|
|
||||||
redirect_uri: Redirect URI that was used in the authorization request.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The JSON payload returned by the userinfo endpoint as a plain ``dict``.
|
|
||||||
"""
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
token_resp = await client.post(
|
|
||||||
token_url,
|
|
||||||
data={
|
|
||||||
"grant_type": "authorization_code",
|
|
||||||
"code": code,
|
|
||||||
"client_id": client_id,
|
|
||||||
"client_secret": client_secret,
|
|
||||||
"redirect_uri": redirect_uri,
|
|
||||||
},
|
|
||||||
headers={"Accept": "application/json"},
|
|
||||||
)
|
|
||||||
token_resp.raise_for_status()
|
|
||||||
access_token = token_resp.json()["access_token"]
|
|
||||||
|
|
||||||
userinfo_resp = await client.get(
|
|
||||||
userinfo_url,
|
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
|
||||||
)
|
|
||||||
userinfo_resp.raise_for_status()
|
|
||||||
return userinfo_resp.json()
|
|
||||||
|
|
||||||
|
|
||||||
def oauth_build_authorization_redirect(
|
|
||||||
authorization_url: str,
|
|
||||||
*,
|
|
||||||
client_id: str,
|
|
||||||
scopes: str,
|
|
||||||
redirect_uri: str,
|
|
||||||
destination: str,
|
|
||||||
) -> RedirectResponse:
|
|
||||||
"""Return an OAuth 2.0 authorization ``RedirectResponse``.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
authorization_url: Provider's authorization endpoint.
|
|
||||||
client_id: OAuth application client ID.
|
|
||||||
scopes: Space-separated list of requested scopes.
|
|
||||||
redirect_uri: URI the provider should redirect back to after authorization.
|
|
||||||
destination: URL the user should be sent to after the full OAuth flow
|
|
||||||
completes (encoded as ``state``).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A :class:`~fastapi.responses.RedirectResponse` to the provider's
|
|
||||||
authorization page.
|
|
||||||
"""
|
|
||||||
params = urlencode(
|
|
||||||
{
|
|
||||||
"client_id": client_id,
|
|
||||||
"response_type": "code",
|
|
||||||
"scope": scopes,
|
|
||||||
"redirect_uri": redirect_uri,
|
|
||||||
"state": oauth_encode_state(destination),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return RedirectResponse(f"{authorization_url}?{params}")
|
|
||||||
|
|
||||||
|
|
||||||
def oauth_encode_state(url: str) -> str:
|
|
||||||
"""Base64url-encode a URL to embed as an OAuth ``state`` parameter."""
|
|
||||||
return base64.urlsafe_b64encode(url.encode()).decode()
|
|
||||||
|
|
||||||
|
|
||||||
def oauth_decode_state(state: str | None, *, fallback: str) -> str:
|
|
||||||
"""Decode a base64url OAuth ``state`` parameter.
|
|
||||||
|
|
||||||
Handles missing padding (some providers strip ``=``).
|
|
||||||
Returns *fallback* if *state* is absent, the literal string ``"null"``,
|
|
||||||
or cannot be decoded.
|
|
||||||
"""
|
|
||||||
if not state or state == "null":
|
|
||||||
return fallback
|
|
||||||
try:
|
|
||||||
padded = state + "=" * (4 - len(state) % 4)
|
|
||||||
return base64.urlsafe_b64decode(padded).decode()
|
|
||||||
except Exception:
|
|
||||||
return fallback
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
"""Built-in authentication source implementations."""
|
|
||||||
|
|
||||||
from .header import APIKeyHeaderAuth
|
|
||||||
from .bearer import BearerTokenAuth
|
|
||||||
from .cookie import CookieAuth
|
|
||||||
from .multi import MultiAuth
|
|
||||||
|
|
||||||
__all__ = ["APIKeyHeaderAuth", "BearerTokenAuth", "CookieAuth", "MultiAuth"]
|
|
||||||
@@ -1,122 +0,0 @@
|
|||||||
"""Bearer token authentication source."""
|
|
||||||
|
|
||||||
import inspect
|
|
||||||
import secrets
|
|
||||||
from typing import Annotated, Any, Callable
|
|
||||||
|
|
||||||
from fastapi import Depends
|
|
||||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer, SecurityScopes
|
|
||||||
|
|
||||||
from fastapi_toolsets.exceptions import UnauthorizedError
|
|
||||||
|
|
||||||
from ..abc import AuthSource, _call_validator
|
|
||||||
|
|
||||||
|
|
||||||
class BearerTokenAuth(AuthSource):
|
|
||||||
"""Bearer token authentication source.
|
|
||||||
|
|
||||||
Wraps :class:`fastapi.security.HTTPBearer` for OpenAPI documentation.
|
|
||||||
The validator is called as ``await validator(credential, **kwargs)``
|
|
||||||
where ``kwargs`` are the extra keyword arguments provided at instantiation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
validator: Sync or async callable that receives the credential and any
|
|
||||||
extra keyword arguments, and returns the authenticated identity
|
|
||||||
(e.g. a ``User`` model). Should raise
|
|
||||||
:class:`~fastapi_toolsets.exceptions.UnauthorizedError` on failure.
|
|
||||||
prefix: Optional token prefix (e.g. ``"user_"``). If set, only tokens
|
|
||||||
whose value starts with this prefix are matched. The prefix is
|
|
||||||
**kept** in the value passed to the validator — store and compare
|
|
||||||
tokens with their prefix included. Use :meth:`generate_token` to
|
|
||||||
create correctly-prefixed tokens. This enables multiple
|
|
||||||
``BearerTokenAuth`` instances in the same app (e.g. ``"user_"``
|
|
||||||
for user tokens, ``"org_"`` for org tokens).
|
|
||||||
**kwargs: Extra keyword arguments forwarded to the validator on every
|
|
||||||
call (e.g. ``role=Role.ADMIN``).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
validator: Callable[..., Any],
|
|
||||||
*,
|
|
||||||
prefix: str | None = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
self._validator = validator
|
|
||||||
self._prefix = prefix
|
|
||||||
self._kwargs = kwargs
|
|
||||||
self._scheme = HTTPBearer(auto_error=False)
|
|
||||||
|
|
||||||
_scheme = self._scheme
|
|
||||||
_validator = validator
|
|
||||||
_kwargs = kwargs
|
|
||||||
_prefix = prefix
|
|
||||||
|
|
||||||
async def _call(
|
|
||||||
security_scopes: SecurityScopes, # noqa: ARG001
|
|
||||||
credentials: Annotated[
|
|
||||||
HTTPAuthorizationCredentials | None, Depends(_scheme)
|
|
||||||
] = None,
|
|
||||||
) -> Any:
|
|
||||||
if credentials is None:
|
|
||||||
raise UnauthorizedError()
|
|
||||||
token = credentials.credentials
|
|
||||||
if _prefix is not None and not token.startswith(_prefix):
|
|
||||||
raise UnauthorizedError()
|
|
||||||
return await _call_validator(_validator, token, **_kwargs)
|
|
||||||
|
|
||||||
self._call_fn = _call
|
|
||||||
self.__signature__ = inspect.signature(_call)
|
|
||||||
|
|
||||||
async def extract(self, request: Any) -> str | None:
|
|
||||||
"""Extract the raw credential from the request without validating.
|
|
||||||
|
|
||||||
Returns ``None`` if no ``Authorization: Bearer`` header is present,
|
|
||||||
the token is empty, or the token does not match the configured prefix.
|
|
||||||
The prefix is included in the returned value.
|
|
||||||
"""
|
|
||||||
auth = request.headers.get("Authorization", "")
|
|
||||||
if not auth.startswith("Bearer "):
|
|
||||||
return None
|
|
||||||
token = auth[7:]
|
|
||||||
if not token:
|
|
||||||
return None
|
|
||||||
if self._prefix is not None and not token.startswith(self._prefix):
|
|
||||||
return None
|
|
||||||
return token
|
|
||||||
|
|
||||||
async def authenticate(self, credential: str) -> Any:
|
|
||||||
"""Validate a credential and return the identity.
|
|
||||||
|
|
||||||
Calls ``await validator(credential, **kwargs)`` where ``kwargs`` are
|
|
||||||
the extra keyword arguments provided at instantiation.
|
|
||||||
"""
|
|
||||||
return await _call_validator(self._validator, credential, **self._kwargs)
|
|
||||||
|
|
||||||
def require(self, **kwargs: Any) -> "BearerTokenAuth":
|
|
||||||
"""Return a new instance with additional (or overriding) validator kwargs."""
|
|
||||||
return BearerTokenAuth(
|
|
||||||
self._validator,
|
|
||||||
prefix=self._prefix,
|
|
||||||
**{**self._kwargs, **kwargs},
|
|
||||||
)
|
|
||||||
|
|
||||||
def generate_token(self, nbytes: int = 32) -> str:
|
|
||||||
"""Generate a secure random token for this auth source.
|
|
||||||
|
|
||||||
Returns a URL-safe random token. If a prefix is configured it is
|
|
||||||
prepended — the returned value is what you store in your database
|
|
||||||
and return to the client as-is.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
nbytes: Number of random bytes before base64 encoding. The
|
|
||||||
resulting string is ``ceil(nbytes * 4 / 3)`` characters
|
|
||||||
(43 chars for the default 32 bytes). Defaults to 32.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A ready-to-use token string (e.g. ``"user_Xk3..."``).
|
|
||||||
"""
|
|
||||||
token = secrets.token_urlsafe(nbytes)
|
|
||||||
if self._prefix is not None:
|
|
||||||
return f"{self._prefix}{token}"
|
|
||||||
return token
|
|
||||||
@@ -1,142 +0,0 @@
|
|||||||
"""Cookie-based authentication source."""
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import hashlib
|
|
||||||
import hmac
|
|
||||||
import inspect
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
from typing import Annotated, Any, Callable
|
|
||||||
|
|
||||||
from fastapi import Depends, Request, Response
|
|
||||||
from fastapi.security import APIKeyCookie, SecurityScopes
|
|
||||||
|
|
||||||
from fastapi_toolsets.exceptions import UnauthorizedError
|
|
||||||
|
|
||||||
from ..abc import AuthSource, _call_validator
|
|
||||||
|
|
||||||
|
|
||||||
class CookieAuth(AuthSource):
|
|
||||||
"""Cookie-based authentication source.
|
|
||||||
|
|
||||||
Wraps :class:`fastapi.security.APIKeyCookie` for OpenAPI documentation.
|
|
||||||
Optionally signs the cookie with HMAC-SHA256 to provide stateless, tamper-
|
|
||||||
proof sessions without any database entry.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Cookie name.
|
|
||||||
validator: Sync or async callable that receives the cookie value
|
|
||||||
(plain, after signature verification when ``secret_key`` is set)
|
|
||||||
and any extra keyword arguments, and returns the authenticated
|
|
||||||
identity.
|
|
||||||
secret_key: When provided, the cookie is HMAC-SHA256 signed.
|
|
||||||
:meth:`set_cookie` embeds an expiry and signs the payload;
|
|
||||||
:meth:`extract` verifies the signature and expiry before handing
|
|
||||||
the plain value to the validator. When ``None`` (default), the raw
|
|
||||||
cookie value is passed to the validator as-is.
|
|
||||||
ttl: Cookie lifetime in seconds (default 24 h). Only used when
|
|
||||||
``secret_key`` is set.
|
|
||||||
**kwargs: Extra keyword arguments forwarded to the validator on every
|
|
||||||
call (e.g. ``role=Role.ADMIN``).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
validator: Callable[..., Any],
|
|
||||||
*,
|
|
||||||
secret_key: str | None = None,
|
|
||||||
ttl: int = 86400,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
self._name = name
|
|
||||||
self._validator = validator
|
|
||||||
self._secret_key = secret_key
|
|
||||||
self._ttl = ttl
|
|
||||||
self._kwargs = kwargs
|
|
||||||
self._scheme = APIKeyCookie(name=name, auto_error=False)
|
|
||||||
|
|
||||||
_scheme = self._scheme
|
|
||||||
_self = self
|
|
||||||
_kwargs = kwargs
|
|
||||||
|
|
||||||
async def _call(
|
|
||||||
security_scopes: SecurityScopes, # noqa: ARG001
|
|
||||||
value: Annotated[str | None, Depends(_scheme)] = None,
|
|
||||||
) -> Any:
|
|
||||||
if value is None:
|
|
||||||
raise UnauthorizedError()
|
|
||||||
plain = _self._verify(value)
|
|
||||||
return await _call_validator(_self._validator, plain, **_kwargs)
|
|
||||||
|
|
||||||
self._call_fn = _call
|
|
||||||
self.__signature__ = inspect.signature(_call)
|
|
||||||
|
|
||||||
def _hmac(self, data: str) -> str:
|
|
||||||
assert self._secret_key is not None
|
|
||||||
return hmac.new(
|
|
||||||
self._secret_key.encode(), data.encode(), hashlib.sha256
|
|
||||||
).hexdigest()
|
|
||||||
|
|
||||||
def _sign(self, value: str) -> str:
|
|
||||||
data = base64.urlsafe_b64encode(
|
|
||||||
json.dumps({"v": value, "exp": int(time.time()) + self._ttl}).encode()
|
|
||||||
).decode()
|
|
||||||
return f"{data}.{self._hmac(data)}"
|
|
||||||
|
|
||||||
def _verify(self, cookie_value: str) -> str:
|
|
||||||
"""Return the plain value, verifying HMAC + expiry when signed."""
|
|
||||||
if not self._secret_key:
|
|
||||||
return cookie_value
|
|
||||||
|
|
||||||
try:
|
|
||||||
data, sig = cookie_value.rsplit(".", 1)
|
|
||||||
except ValueError:
|
|
||||||
raise UnauthorizedError()
|
|
||||||
|
|
||||||
if not hmac.compare_digest(self._hmac(data), sig):
|
|
||||||
raise UnauthorizedError()
|
|
||||||
|
|
||||||
try:
|
|
||||||
payload = json.loads(base64.urlsafe_b64decode(data))
|
|
||||||
value: str = payload["v"]
|
|
||||||
exp: int = payload["exp"]
|
|
||||||
except Exception:
|
|
||||||
raise UnauthorizedError()
|
|
||||||
|
|
||||||
if exp < int(time.time()):
|
|
||||||
raise UnauthorizedError()
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
async def extract(self, request: Request) -> str | None:
|
|
||||||
return request.cookies.get(self._name)
|
|
||||||
|
|
||||||
async def authenticate(self, credential: str) -> Any:
|
|
||||||
plain = self._verify(credential)
|
|
||||||
return await _call_validator(self._validator, plain, **self._kwargs)
|
|
||||||
|
|
||||||
def require(self, **kwargs: Any) -> "CookieAuth":
|
|
||||||
"""Return a new instance with additional (or overriding) validator kwargs."""
|
|
||||||
return CookieAuth(
|
|
||||||
self._name,
|
|
||||||
self._validator,
|
|
||||||
secret_key=self._secret_key,
|
|
||||||
ttl=self._ttl,
|
|
||||||
**{**self._kwargs, **kwargs},
|
|
||||||
)
|
|
||||||
|
|
||||||
def set_cookie(self, response: Response, value: str) -> None:
|
|
||||||
"""Attach the cookie to *response*, signing it when ``secret_key`` is set."""
|
|
||||||
cookie_value = self._sign(value) if self._secret_key else value
|
|
||||||
response.set_cookie(
|
|
||||||
self._name,
|
|
||||||
cookie_value,
|
|
||||||
httponly=True,
|
|
||||||
samesite="lax",
|
|
||||||
max_age=self._ttl,
|
|
||||||
)
|
|
||||||
|
|
||||||
def delete_cookie(self, response: Response) -> None:
|
|
||||||
"""Clear the session cookie (logout)."""
|
|
||||||
response.delete_cookie(self._name, httponly=True, samesite="lax")
|
|
||||||
@@ -1,71 +0,0 @@
|
|||||||
"""API key header authentication source."""
|
|
||||||
|
|
||||||
import inspect
|
|
||||||
from typing import Annotated, Any, Callable
|
|
||||||
|
|
||||||
from fastapi import Depends, Request
|
|
||||||
from fastapi.security import APIKeyHeader, SecurityScopes
|
|
||||||
|
|
||||||
from fastapi_toolsets.exceptions import UnauthorizedError
|
|
||||||
|
|
||||||
from ..abc import AuthSource, _call_validator
|
|
||||||
|
|
||||||
|
|
||||||
class APIKeyHeaderAuth(AuthSource):
|
|
||||||
"""API key header authentication source.
|
|
||||||
|
|
||||||
Wraps :class:`fastapi.security.APIKeyHeader` for OpenAPI documentation.
|
|
||||||
The validator is called as ``await validator(api_key, **kwargs)``
|
|
||||||
where ``kwargs`` are the extra keyword arguments provided at instantiation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: HTTP header name that carries the API key (e.g. ``"X-API-Key"``).
|
|
||||||
validator: Sync or async callable that receives the API key and any
|
|
||||||
extra keyword arguments, and returns the authenticated identity.
|
|
||||||
Should raise :class:`~fastapi_toolsets.exceptions.UnauthorizedError`
|
|
||||||
on failure.
|
|
||||||
**kwargs: Extra keyword arguments forwarded to the validator on every
|
|
||||||
call (e.g. ``role=Role.ADMIN``).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
validator: Callable[..., Any],
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
self._name = name
|
|
||||||
self._validator = validator
|
|
||||||
self._kwargs = kwargs
|
|
||||||
self._scheme = APIKeyHeader(name=name, auto_error=False)
|
|
||||||
|
|
||||||
_scheme = self._scheme
|
|
||||||
_validator = validator
|
|
||||||
_kwargs = kwargs
|
|
||||||
|
|
||||||
async def _call(
|
|
||||||
security_scopes: SecurityScopes, # noqa: ARG001
|
|
||||||
api_key: Annotated[str | None, Depends(_scheme)] = None,
|
|
||||||
) -> Any:
|
|
||||||
if api_key is None:
|
|
||||||
raise UnauthorizedError()
|
|
||||||
return await _call_validator(_validator, api_key, **_kwargs)
|
|
||||||
|
|
||||||
self._call_fn = _call
|
|
||||||
self.__signature__ = inspect.signature(_call)
|
|
||||||
|
|
||||||
async def extract(self, request: Request) -> str | None:
|
|
||||||
"""Extract the API key from the configured header."""
|
|
||||||
return request.headers.get(self._name) or None
|
|
||||||
|
|
||||||
async def authenticate(self, credential: str) -> Any:
|
|
||||||
"""Validate a credential and return the identity."""
|
|
||||||
return await _call_validator(self._validator, credential, **self._kwargs)
|
|
||||||
|
|
||||||
def require(self, **kwargs: Any) -> "APIKeyHeaderAuth":
|
|
||||||
"""Return a new instance with additional (or overriding) validator kwargs."""
|
|
||||||
return APIKeyHeaderAuth(
|
|
||||||
self._name,
|
|
||||||
self._validator,
|
|
||||||
**{**self._kwargs, **kwargs},
|
|
||||||
)
|
|
||||||
@@ -1,121 +0,0 @@
|
|||||||
"""MultiAuth: combine multiple authentication sources into a single callable."""
|
|
||||||
|
|
||||||
import inspect
|
|
||||||
from typing import Any, cast
|
|
||||||
|
|
||||||
from fastapi import Request
|
|
||||||
from fastapi.security import SecurityScopes
|
|
||||||
|
|
||||||
from fastapi_toolsets.exceptions import UnauthorizedError
|
|
||||||
|
|
||||||
from ..abc import AuthSource
|
|
||||||
|
|
||||||
|
|
||||||
class MultiAuth:
|
|
||||||
"""Combine multiple authentication sources into a single callable.
|
|
||||||
|
|
||||||
Sources are tried in order; the first one whose
|
|
||||||
:meth:`~AuthSource.extract` returns a non-``None`` credential wins.
|
|
||||||
Its :meth:`~AuthSource.authenticate` is called and the result returned.
|
|
||||||
|
|
||||||
If a credential is found but the validator raises, the exception propagates
|
|
||||||
immediately — the remaining sources are **not** tried. This prevents
|
|
||||||
silent fallthrough on invalid credentials.
|
|
||||||
|
|
||||||
If no source provides a credential,
|
|
||||||
:class:`~fastapi_toolsets.exceptions.UnauthorizedError` is raised.
|
|
||||||
|
|
||||||
The :meth:`~AuthSource.extract` method of each source performs only
|
|
||||||
string matching (no I/O), so prefix-based dispatch is essentially free.
|
|
||||||
|
|
||||||
Any :class:`~AuthSource` subclass — including user-defined ones — can be
|
|
||||||
passed as a source.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
*sources: Auth source instances to try in order.
|
|
||||||
|
|
||||||
Example::
|
|
||||||
|
|
||||||
user_bearer = BearerTokenAuth(verify_user, prefix="user_")
|
|
||||||
org_bearer = BearerTokenAuth(verify_org, prefix="org_")
|
|
||||||
cookie = CookieAuth("session", verify_session)
|
|
||||||
|
|
||||||
multi = MultiAuth(user_bearer, org_bearer, cookie)
|
|
||||||
|
|
||||||
@app.get("/data")
|
|
||||||
async def data_route(user = Security(multi)):
|
|
||||||
return user
|
|
||||||
|
|
||||||
# Apply a shared requirement to all sources at once
|
|
||||||
@app.get("/admin")
|
|
||||||
async def admin_route(user = Security(multi.require(role=Role.ADMIN))):
|
|
||||||
return user
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, *sources: AuthSource) -> None:
|
|
||||||
self._sources = sources
|
|
||||||
|
|
||||||
_sources = sources
|
|
||||||
|
|
||||||
async def _call(
|
|
||||||
request: Request,
|
|
||||||
security_scopes: SecurityScopes, # noqa: ARG001
|
|
||||||
**kwargs: Any, # noqa: ARG001 — absorbs scheme values injected by FastAPI
|
|
||||||
) -> Any:
|
|
||||||
for source in _sources:
|
|
||||||
credential = await source.extract(request)
|
|
||||||
if credential is not None:
|
|
||||||
return await source.authenticate(credential)
|
|
||||||
raise UnauthorizedError()
|
|
||||||
|
|
||||||
self._call_fn = _call
|
|
||||||
|
|
||||||
# Build a merged signature that includes the security-scheme Depends()
|
|
||||||
# parameters from every source so FastAPI registers them in OpenAPI docs.
|
|
||||||
seen: set[str] = {"request", "security_scopes"}
|
|
||||||
merged: list[inspect.Parameter] = [
|
|
||||||
inspect.Parameter(
|
|
||||||
"request",
|
|
||||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
||||||
annotation=Request,
|
|
||||||
),
|
|
||||||
inspect.Parameter(
|
|
||||||
"security_scopes",
|
|
||||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
||||||
annotation=SecurityScopes,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
for i, source in enumerate(sources):
|
|
||||||
for name, param in inspect.signature(source).parameters.items():
|
|
||||||
if name in seen:
|
|
||||||
continue
|
|
||||||
merged.append(param.replace(name=f"_s{i}_{name}"))
|
|
||||||
seen.add(name)
|
|
||||||
self.__signature__ = inspect.Signature(merged, return_annotation=Any)
|
|
||||||
|
|
||||||
async def __call__(self, **kwargs: Any) -> Any:
|
|
||||||
return await self._call_fn(**kwargs)
|
|
||||||
|
|
||||||
def require(self, **kwargs: Any) -> "MultiAuth":
|
|
||||||
"""Return a new :class:`MultiAuth` with kwargs forwarded to each source.
|
|
||||||
|
|
||||||
Calls ``.require(**kwargs)`` on every source that supports it. Sources
|
|
||||||
that do not implement ``.require()`` (e.g. custom :class:`~AuthSource`
|
|
||||||
subclasses) are passed through unchanged.
|
|
||||||
|
|
||||||
New kwargs are merged over each source's existing kwargs — new values
|
|
||||||
win on conflict::
|
|
||||||
|
|
||||||
multi = MultiAuth(bearer, cookie)
|
|
||||||
|
|
||||||
@app.get("/admin")
|
|
||||||
async def admin(user = Security(multi.require(role=Role.ADMIN))):
|
|
||||||
return user
|
|
||||||
"""
|
|
||||||
new_sources = tuple(
|
|
||||||
cast(Any, source).require(**kwargs)
|
|
||||||
if hasattr(source, "require")
|
|
||||||
else source
|
|
||||||
for source in self._sources
|
|
||||||
)
|
|
||||||
return MultiAuth(*new_sources)
|
|
||||||
@@ -35,6 +35,7 @@ from .conftest import (
|
|||||||
RoleCursorCrud,
|
RoleCursorCrud,
|
||||||
RoleRead,
|
RoleRead,
|
||||||
RoleUpdate,
|
RoleUpdate,
|
||||||
|
Tag,
|
||||||
TagCreate,
|
TagCreate,
|
||||||
TagCrud,
|
TagCrud,
|
||||||
User,
|
User,
|
||||||
@@ -294,6 +295,100 @@ class TestCrudGet:
|
|||||||
assert user.username == "active"
|
assert user.username == "active"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCrudGetOrNone:
|
||||||
|
"""Tests for CRUD get_or_none operations."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_returns_record_when_found(self, db_session: AsyncSession):
|
||||||
|
"""get_or_none returns the record when it exists."""
|
||||||
|
created = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||||
|
fetched = await RoleCrud.get_or_none(db_session, [Role.id == created.id])
|
||||||
|
|
||||||
|
assert fetched is not None
|
||||||
|
assert fetched.id == created.id
|
||||||
|
assert fetched.name == "admin"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_returns_none_when_not_found(self, db_session: AsyncSession):
|
||||||
|
"""get_or_none returns None instead of raising NotFoundError."""
|
||||||
|
result = await RoleCrud.get_or_none(db_session, [Role.id == uuid.uuid4()])
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_with_schema_returns_response_when_found(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""get_or_none with schema returns Response[schema] when found."""
|
||||||
|
from fastapi_toolsets.schemas import Response
|
||||||
|
|
||||||
|
created = await RoleCrud.create(db_session, RoleCreate(name="editor"))
|
||||||
|
result = await RoleCrud.get_or_none(
|
||||||
|
db_session, [Role.id == created.id], schema=RoleRead
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, Response)
|
||||||
|
assert isinstance(result.data, RoleRead)
|
||||||
|
assert result.data.name == "editor"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_with_schema_returns_none_when_not_found(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""get_or_none with schema returns None (not Response) when not found."""
|
||||||
|
result = await RoleCrud.get_or_none(
|
||||||
|
db_session, [Role.id == uuid.uuid4()], schema=RoleRead
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_with_load_options(self, db_session: AsyncSession):
|
||||||
|
"""get_or_none respects load_options."""
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="member"))
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="alice", email="alice@test.com", role_id=role.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
fetched = await UserCrud.get_or_none(
|
||||||
|
db_session,
|
||||||
|
[User.id == user.id],
|
||||||
|
load_options=[selectinload(User.role)],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert fetched is not None
|
||||||
|
assert fetched.role is not None
|
||||||
|
assert fetched.role.name == "member"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_with_join(self, db_session: AsyncSession):
|
||||||
|
"""get_or_none respects joins."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Published", author_id=user.id, is_published=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
fetched = await UserCrud.get_or_none(
|
||||||
|
db_session,
|
||||||
|
[User.id == user.id, Post.is_published == True], # noqa: E712
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
)
|
||||||
|
assert fetched is not None
|
||||||
|
assert fetched.id == user.id
|
||||||
|
|
||||||
|
# Filter that matches no join — returns None
|
||||||
|
missing = await UserCrud.get_or_none(
|
||||||
|
db_session,
|
||||||
|
[User.id == user.id, Post.is_published == False], # noqa: E712
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
)
|
||||||
|
assert missing is None
|
||||||
|
|
||||||
|
|
||||||
class TestCrudFirst:
|
class TestCrudFirst:
|
||||||
"""Tests for CRUD first operations."""
|
"""Tests for CRUD first operations."""
|
||||||
|
|
||||||
@@ -321,6 +416,38 @@ class TestCrudFirst:
|
|||||||
role = await RoleCrud.first(db_session)
|
role = await RoleCrud.first(db_session)
|
||||||
assert role is not None
|
assert role is not None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_first_with_schema(self, db_session: AsyncSession):
|
||||||
|
"""First with schema returns a Response wrapping the serialized record."""
|
||||||
|
await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||||
|
|
||||||
|
result = await RoleCrud.first(
|
||||||
|
db_session, [Role.name == "admin"], schema=RoleRead
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.data is not None
|
||||||
|
assert result.data.name == "admin"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_first_with_schema_not_found(self, db_session: AsyncSession):
|
||||||
|
"""First with schema returns None when no record matches."""
|
||||||
|
result = await RoleCrud.first(
|
||||||
|
db_session, [Role.name == "ghost"], schema=RoleRead
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_first_with_for_update(self, db_session: AsyncSession):
|
||||||
|
"""First with with_for_update locks the row."""
|
||||||
|
await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||||
|
|
||||||
|
role = await RoleCrud.first(
|
||||||
|
db_session, [Role.name == "admin"], with_for_update=True
|
||||||
|
)
|
||||||
|
assert role is not None
|
||||||
|
assert role.name == "admin"
|
||||||
|
|
||||||
|
|
||||||
class TestCrudGetMulti:
|
class TestCrudGetMulti:
|
||||||
"""Tests for CRUD get_multi operations."""
|
"""Tests for CRUD get_multi operations."""
|
||||||
@@ -480,6 +607,69 @@ class TestCrudDelete:
|
|||||||
assert result.data is None
|
assert result.data is None
|
||||||
assert await RoleCrud.first(db_session, [Role.id == role.id]) is None
|
assert await RoleCrud.first(db_session, [Role.id == role.id]) is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_delete_m2m_cascade(self, db_session: AsyncSession):
|
||||||
|
"""Deleting a record with M2M relationships cleans up the association table."""
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
tag1 = await TagCrud.create(db_session, TagCreate(name="python"))
|
||||||
|
tag2 = await TagCrud.create(db_session, TagCreate(name="fastapi"))
|
||||||
|
|
||||||
|
post = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(
|
||||||
|
title="M2M Delete Test",
|
||||||
|
author_id=user.id,
|
||||||
|
tag_ids=[tag1.id, tag2.id],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
await PostM2MCrud.delete(db_session, [Post.id == post.id])
|
||||||
|
|
||||||
|
# Post is gone
|
||||||
|
assert await PostCrud.first(db_session, [Post.id == post.id]) is None
|
||||||
|
|
||||||
|
# Association rows are gone — tags themselves must still exist
|
||||||
|
assert await TagCrud.first(db_session, [Tag.id == tag1.id]) is not None
|
||||||
|
assert await TagCrud.first(db_session, [Tag.id == tag2.id]) is not None
|
||||||
|
|
||||||
|
# No orphaned rows in post_tags
|
||||||
|
result = await db_session.execute(
|
||||||
|
text("SELECT COUNT(*) FROM post_tags WHERE post_id = :pid").bindparams(
|
||||||
|
pid=post.id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert result.scalar() == 0
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_delete_m2m_does_not_delete_related_records(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""Deleting a post with M2M tags must not delete the tags themselves."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author2", email="author2@test.com")
|
||||||
|
)
|
||||||
|
tag = await TagCrud.create(db_session, TagCreate(name="shared_tag"))
|
||||||
|
|
||||||
|
post1 = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(title="Post 1", author_id=user.id, tag_ids=[tag.id]),
|
||||||
|
)
|
||||||
|
post2 = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(title="Post 2", author_id=user.id, tag_ids=[tag.id]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete only post1
|
||||||
|
await PostM2MCrud.delete(db_session, [Post.id == post1.id])
|
||||||
|
|
||||||
|
# Tag and post2 still exist
|
||||||
|
assert await TagCrud.first(db_session, [Tag.id == tag.id]) is not None
|
||||||
|
assert await PostCrud.first(db_session, [Post.id == post2.id]) is not None
|
||||||
|
|
||||||
|
|
||||||
class TestCrudExists:
|
class TestCrudExists:
|
||||||
"""Tests for CRUD exists operations."""
|
"""Tests for CRUD exists operations."""
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
2
uv.lock
generated
2
uv.lock
generated
@@ -251,7 +251,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fastapi-toolsets"
|
name = "fastapi-toolsets"
|
||||||
version = "2.1.0"
|
version = "2.2.0"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "asyncpg" },
|
{ name = "asyncpg" },
|
||||||
|
|||||||
Reference in New Issue
Block a user