mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 06:36:26 +02:00
Compare commits
174 Commits
v0.1.0
...
6e999985c0
| Author | SHA1 | Date | |
|---|---|---|---|
|
6e999985c0
|
|||
|
c3d1fe977d
|
|||
|
92036d6b88
|
|||
|
ba6c267897
|
|||
|
|
e38d8d2d4f | ||
|
9b74f162ab
|
|||
|
|
ab125c6ea1 | ||
|
|
e388e26858 | ||
|
|
04da241294 | ||
|
|
bbe63edc46 | ||
|
|
0b17c77dee | ||
|
|
bce71bfd42 | ||
|
|
2f1eb4d468 | ||
|
|
1f06eab11d | ||
|
|
fac9aa6f60 | ||
|
|
f310466697 | ||
|
|
32059dcb02 | ||
|
|
f027981e80 | ||
|
|
5c1487c24a | ||
|
|
ebaa61525f | ||
|
|
4829cfba73 | ||
|
|
9ca2da4213 | ||
|
|
0b3f097012 | ||
|
|
1890d696bf | ||
|
|
104285c6e5 | ||
|
|
f5afbbe37f | ||
|
|
f4698bea8a | ||
|
|
5215b921ae | ||
|
9dad59e25d
|
|||
|
|
29326ab532 | ||
|
|
04afef7e33 | ||
|
|
666c621fda | ||
|
460b760fa4
|
|||
|
|
65d0b0e0b1 | ||
|
|
2d49cd32db | ||
|
|
a5dd756d87 | ||
|
|
781cfb66c9 | ||
|
|
91b84f8146 | ||
|
|
396e381ac3 | ||
|
|
b4eb4c1ca9 | ||
|
c90717754f
|
|||
|
337985ef38
|
|||
|
|
b5e6dfe6fe | ||
|
|
6681b7ade7 | ||
|
|
6981c33dc8 | ||
|
|
0c7a99039c | ||
|
|
bcb5b0bfda | ||
|
100e1c1aa9
|
|||
|
|
db6c7a565f | ||
|
|
768e405554 | ||
|
|
f0223ebde4 | ||
|
|
f8c9bf69fe | ||
|
6d6fae5538
|
|||
|
|
fc9cd1f034 | ||
|
|
f82225f995 | ||
|
|
e62612a93a | ||
|
|
56f0ea291e | ||
|
|
ee896009ee | ||
|
|
65bf928e12 | ||
|
|
2e9c6c0c90 | ||
|
2c494fcd17
|
|||
|
|
fd7269a372 | ||
|
|
c863744012 | ||
|
|
aedcbf4e04 | ||
|
|
19c013bdec | ||
|
|
81407c3038 | ||
|
|
0fb00d44da | ||
|
|
19232d3436 | ||
|
1eafcb3873
|
|||
|
|
0d67fbb58d | ||
|
|
a59f098930 | ||
|
|
96e34ba8af | ||
|
|
26d649791f | ||
|
dde5183e68
|
|||
|
|
e4250a9910 | ||
|
|
4800941934 | ||
|
0cc21d2012
|
|||
|
|
a3245d50f0 | ||
|
|
baebf022f6 | ||
|
|
96d445e3f3 | ||
|
|
80306e1af3 | ||
|
|
fd999b63f1 | ||
|
|
c0f352b914 | ||
|
c4c760484b
|
|||
|
|
432e0722e0 | ||
|
|
e732e54518 | ||
|
|
05b5a2c876 | ||
|
|
4a020c56d1 | ||
|
56d365d14b
|
|||
|
|
a257d85d45 | ||
|
117675d02f
|
|||
|
|
d7ad7308c5 | ||
|
8d57bf9525
|
|||
|
|
5a08ec2f57 | ||
|
|
433dc55fcd | ||
|
|
0b2abd8c43 | ||
|
|
07c99be89b | ||
|
|
9b75cc7dfc | ||
|
|
6144b383eb | ||
|
7ec407834a
|
|||
|
|
7da34f33a2 | ||
|
8c8911fb27
|
|||
|
|
c0c3b38054 | ||
|
e17d385910
|
|||
|
|
6cf7df55ef | ||
|
|
7482bc5dad | ||
|
|
9d07dfea85 | ||
|
|
31678935aa | ||
|
|
823a0b3e36 | ||
|
1591cd3d64
|
|||
|
|
6714ceeb92 | ||
|
|
73fae04333 | ||
|
|
32ed36e102 | ||
|
|
48567310bc | ||
|
|
de51ed4675 | ||
|
|
794767edbb | ||
|
|
9c136f05bb | ||
|
3299a439fe
|
|||
|
|
d5b22a72fd | ||
|
|
c32f2e18be | ||
|
d971261f98
|
|||
|
|
74a54b7396 | ||
|
|
19805ab376 | ||
|
|
d4498e2063 | ||
| f59c1a17e2 | |||
|
|
8982ba18e3 | ||
|
|
71fe6f478f | ||
|
|
1cfbf14986 | ||
|
|
e3ff535b7e | ||
|
|
8825c772ce | ||
|
|
c8c263ca8f | ||
|
2020fa2f92
|
|||
|
|
1ea316bef4 | ||
|
|
ced1a655f2 | ||
|
|
290b2a06ec | ||
|
|
baa9711665 | ||
|
d526969d0e
|
|||
|
|
e24153053e | ||
|
348ed4c148
|
|||
|
bd6e90de1b
|
|||
|
|
4404fb3df9 | ||
|
|
f68793fbdb | ||
|
|
3a69c3c788 | ||
|
e861a0a49a
|
|||
|
|
cb2cf572e0 | ||
|
494869a172
|
|||
|
|
e0bc93096d | ||
|
1ff94eb9d3
|
|||
|
|
97ab10edcd | ||
|
|
3ff7ff18bb | ||
|
0f50c8a0f0
|
|||
|
|
691fb78fda | ||
|
|
34ef4da317 | ||
|
|
8c287b3ce7 | ||
|
54f5479c24
|
|||
|
|
f467754df1 | ||
|
b57ce40b05
|
|||
|
5264631550
|
|||
|
a76f7c439d
|
|||
|
|
d14551781c | ||
|
|
577e087321 | ||
|
aa72dc2eb5
|
|||
|
|
1a98e36909 | ||
|
ba5180a73b
|
|||
|
a9f486d905
|
|||
|
53e80cd0d5
|
|||
|
|
45001767aa | ||
|
|
cd551b6bff | ||
|
|
718a12be28 | ||
|
|
fa16bf1bff | ||
|
|
c4a227f9fc | ||
| fe1ccabdd8 | |||
|
|
9e7473fbf5 | ||
|
|
d9d7f60e8e |
14
.github/dependabot.yml
vendored
Normal file
14
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
version: 2
|
||||||
|
updates:
|
||||||
|
- package-ecosystem: "github-actions"
|
||||||
|
directory: "/"
|
||||||
|
schedule:
|
||||||
|
interval: "weekly"
|
||||||
|
commit-message:
|
||||||
|
prefix: ⬆
|
||||||
|
- package-ecosystem: "uv"
|
||||||
|
directory: "/"
|
||||||
|
schedule:
|
||||||
|
interval: "weekly"
|
||||||
|
commit-message:
|
||||||
|
prefix: ⬆
|
||||||
8
.github/workflows/build-release.yml
vendored
8
.github/workflows/build-release.yml
vendored
@@ -11,16 +11,16 @@ jobs:
|
|||||||
permissions:
|
permissions:
|
||||||
id-token: write
|
id-token: write
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Install uv
|
- name: Install uv
|
||||||
uses: astral-sh/setup-uv@v5
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
run: uv python install 3.13
|
run: uv python install 3.14
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync
|
run: uv sync --group dev
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
run: uv build
|
run: uv build
|
||||||
|
|||||||
34
.github/workflows/ci.yml
vendored
34
.github/workflows/ci.yml
vendored
@@ -15,16 +15,16 @@ jobs:
|
|||||||
name: Lint (Ruff)
|
name: Lint (Ruff)
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Install uv
|
- name: Install uv
|
||||||
uses: astral-sh/setup-uv@v5
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
run: uv python install 3.13
|
run: uv python install 3.13
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync --extra dev
|
run: uv sync --group dev
|
||||||
|
|
||||||
- name: Run Ruff linter
|
- name: Run Ruff linter
|
||||||
run: uv run ruff check .
|
run: uv run ruff check .
|
||||||
@@ -36,16 +36,16 @@ jobs:
|
|||||||
name: Type Check (ty)
|
name: Type Check (ty)
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Install uv
|
- name: Install uv
|
||||||
uses: astral-sh/setup-uv@v5
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
run: uv python install 3.13
|
run: uv python install 3.13
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync --extra dev
|
run: uv sync --group dev
|
||||||
|
|
||||||
- name: Run ty
|
- name: Run ty
|
||||||
run: uv run ty check
|
run: uv run ty check
|
||||||
@@ -56,7 +56,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ["3.11", "3.12", "3.13"]
|
python-version: ["3.11", "3.12", "3.13", "3.14"]
|
||||||
|
|
||||||
services:
|
services:
|
||||||
postgres:
|
postgres:
|
||||||
@@ -74,27 +74,35 @@ jobs:
|
|||||||
--health-retries 5
|
--health-retries 5
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Install uv
|
- name: Install uv
|
||||||
uses: astral-sh/setup-uv@v5
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
run: uv python install ${{ matrix.python-version }}
|
run: uv python install ${{ matrix.python-version }}
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync --extra dev
|
run: uv sync --group dev
|
||||||
|
|
||||||
- name: Run tests with coverage
|
- name: Run tests with coverage
|
||||||
env:
|
env:
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/test_db
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/test_db
|
||||||
run: |
|
run: |
|
||||||
uv run pytest --cov --cov-report=xml --cov-report=term-missing
|
uv run pytest --cov --cov-report=xml --cov-report=term-missing --junitxml=junit.xml -o junit_family=legacy
|
||||||
|
|
||||||
- name: Upload coverage to Codecov
|
- name: Upload coverage to Codecov
|
||||||
if: matrix.python-version == '3.13'
|
if: matrix.python-version == '3.14'
|
||||||
uses: codecov/codecov-action@v5
|
uses: codecov/codecov-action@v6
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
|
report_type: coverage
|
||||||
files: ./coverage.xml
|
files: ./coverage.xml
|
||||||
fail_ci_if_error: false
|
fail_ci_if_error: false
|
||||||
|
|
||||||
|
- name: Upload test results to Codecov
|
||||||
|
if: matrix.python-version == '3.14'
|
||||||
|
uses: codecov/codecov-action@v6
|
||||||
|
with:
|
||||||
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
|
report_type: test_results
|
||||||
|
|||||||
53
.github/workflows/docs.yml
vendored
Normal file
53
.github/workflows/docs.yml
vendored
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
name: Documentation
|
||||||
|
|
||||||
|
on:
|
||||||
|
release:
|
||||||
|
types: [published]
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
deploy:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v6
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
run: uv python install 3.13
|
||||||
|
|
||||||
|
- run: uv sync --group dev
|
||||||
|
|
||||||
|
- name: Configure git
|
||||||
|
run: |
|
||||||
|
git config user.name "github-actions[bot]"
|
||||||
|
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||||
|
|
||||||
|
- name: Deploy documentation
|
||||||
|
run: |
|
||||||
|
VERSION=${GITHUB_REF_NAME#v}
|
||||||
|
MAJOR=$(echo "$VERSION" | cut -d. -f1)
|
||||||
|
DEPLOY_VERSION="v$(echo "$VERSION" | cut -d. -f1-2)"
|
||||||
|
|
||||||
|
# On new major: keep only the latest feature version of the previous major
|
||||||
|
PREV_MAJOR=$((MAJOR - 1))
|
||||||
|
OLD_FEATURE_VERSIONS=$(uv run mike list 2>/dev/null | grep -oE "^v${PREV_MAJOR}\.[0-9]+" || true)
|
||||||
|
|
||||||
|
if [ -n "$OLD_FEATURE_VERSIONS" ]; then
|
||||||
|
LATEST_PREV=$(echo "$OLD_FEATURE_VERSIONS" | sort -t. -k2 -n | tail -1)
|
||||||
|
echo "$OLD_FEATURE_VERSIONS" | while read -r OLD_V; do
|
||||||
|
if [ "$OLD_V" != "$LATEST_PREV" ]; then
|
||||||
|
echo "Deleting $OLD_V"
|
||||||
|
uv run mike delete "$OLD_V"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
uv run mike deploy --update-aliases "$DEPLOY_VERSION" stable
|
||||||
|
uv run mike set-default stable
|
||||||
|
git push origin gh-pages
|
||||||
34
.pre-commit-config.yaml
Normal file
34
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
# See https://pre-commit.com for more information
|
||||||
|
# See https://pre-commit.com/hooks.html for more hooks
|
||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v6.0.0
|
||||||
|
hooks:
|
||||||
|
- id: check-added-large-files
|
||||||
|
args: ["--maxkb=750"]
|
||||||
|
exclude: ^uv.lock$
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
- id: trailing-whitespace
|
||||||
|
|
||||||
|
- repo: local
|
||||||
|
hooks:
|
||||||
|
- id: local-ruff-check
|
||||||
|
name: ruff check
|
||||||
|
entry: uv run ruff check --force-exclude --fix --exit-non-zero-on-fix .
|
||||||
|
require_serial: true
|
||||||
|
language: unsupported
|
||||||
|
types: [python]
|
||||||
|
|
||||||
|
- id: local-ruff-format
|
||||||
|
name: ruff format
|
||||||
|
entry: uv run ruff format --force-exclude --exit-non-zero-on-format .
|
||||||
|
require_serial: true
|
||||||
|
language: unsupported
|
||||||
|
types: [python]
|
||||||
|
|
||||||
|
- id: local-ty
|
||||||
|
name: ty check
|
||||||
|
entry: uv run ty check
|
||||||
|
require_serial: true
|
||||||
|
language: unsupported
|
||||||
|
pass_filenames: false
|
||||||
@@ -1 +1 @@
|
|||||||
3.13
|
3.14
|
||||||
|
|||||||
38
README.md
38
README.md
@@ -1,6 +1,6 @@
|
|||||||
# FastAPI Toolsets
|
# FastAPI Toolsets
|
||||||
|
|
||||||
FastAPI Toolsets provides production-ready utilities for FastAPI applications built with async SQLAlchemy and PostgreSQL. It includes generic CRUD operations, a fixture system with dependency resolution, a Django-like CLI, standardized API responses, and structured exception handling with automatic OpenAPI documentation.
|
A modular collection of production-ready utilities for FastAPI. Install only what you need — from async CRUD and database helpers to CLI tooling, Prometheus metrics, and pytest fixtures. Each module is independently installable via optional extras, keeping your dependency footprint minimal.
|
||||||
|
|
||||||
[](https://github.com/d3vyce/fastapi-toolsets/actions/workflows/ci.yml)
|
[](https://github.com/d3vyce/fastapi-toolsets/actions/workflows/ci.yml)
|
||||||
[](https://codecov.io/gh/d3vyce/fastapi-toolsets)
|
[](https://codecov.io/gh/d3vyce/fastapi-toolsets)
|
||||||
@@ -20,17 +20,45 @@ FastAPI Toolsets provides production-ready utilities for FastAPI applications bu
|
|||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
|
The base package includes the core modules (CRUD, database, schemas, exceptions, fixtures, dependencies, model mixins, logging):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv add fastapi-toolsets
|
uv add fastapi-toolsets
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Install only the extras you need:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add "fastapi-toolsets[cli]"
|
||||||
|
uv add "fastapi-toolsets[metrics]"
|
||||||
|
uv add "fastapi-toolsets[pytest]"
|
||||||
|
```
|
||||||
|
|
||||||
|
Or install everything:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add "fastapi-toolsets[all]"
|
||||||
|
```
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- **CRUD**: Generic async CRUD operations with `CrudFactory`
|
### Core
|
||||||
- **Fixtures**: Fixture system with dependency management, context support and pytest integration
|
|
||||||
- **CLI**: Django-like command-line interface for fixtures and custom commands
|
- **CRUD**: Generic async CRUD operations with `CrudFactory`, built-in full-text/faceted search and Offset/Cursor pagination.
|
||||||
- **Standardized API Responses**: Consistent response format across your API
|
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
|
||||||
|
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
|
||||||
|
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
|
||||||
|
- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`)
|
||||||
|
- **Lifecycle Events**: Post-commit event system (`EventSession`, `listens_for`) that dispatches async/sync callbacks for insert, update, and delete operations
|
||||||
|
- **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`.
|
||||||
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
||||||
|
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`
|
||||||
|
|
||||||
|
### Optional
|
||||||
|
|
||||||
|
- **CLI**: Django-like command-line interface with fixture management and custom commands support
|
||||||
|
- **Metrics**: Prometheus metrics endpoint with provider/collector registry
|
||||||
|
- **Pytest Helpers**: Async test client, database session management, `pytest-xdist` support, and table cleanup utilities
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
|
|||||||
1
docs/examples/authentication.md
Normal file
1
docs/examples/authentication.md
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# Authentication
|
||||||
185
docs/examples/pagination-search.md
Normal file
185
docs/examples/pagination-search.md
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
# Pagination & search
|
||||||
|
|
||||||
|
This example builds an articles listing endpoint that supports **offset pagination**, **cursor pagination**, **full-text search**, **faceted filtering**, and **sorting** — all from a single `CrudFactory` definition.
|
||||||
|
|
||||||
|
## Models
|
||||||
|
|
||||||
|
```python title="models.py"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/models.py"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Schemas
|
||||||
|
|
||||||
|
```python title="schemas.py"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/schemas.py"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Crud
|
||||||
|
|
||||||
|
Declare `searchable_fields`, `facet_fields`, and `order_fields` once on [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory). All endpoints built from this class share the same defaults and can override them per call.
|
||||||
|
|
||||||
|
```python title="crud.py"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/crud.py"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Session dependency
|
||||||
|
|
||||||
|
```python title="db.py"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/db.py"
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! info "Deploy a Postgres DB with docker"
|
||||||
|
```bash
|
||||||
|
docker run -d --name postgres -e POSTGRES_USER=postgres -e POSTGRES_PASSWORD=postgres -e POSTGRES_DB=postgres -p 5432:5432 postgres:18-alpine
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## App
|
||||||
|
|
||||||
|
```python title="app.py"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/app.py"
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Routes
|
||||||
|
|
||||||
|
```python title="routes.py:1:16"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/routes.py:1:16"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Offset pagination
|
||||||
|
|
||||||
|
Best for admin panels or any UI that needs a total item count and numbered pages.
|
||||||
|
|
||||||
|
```python title="routes.py:19:37"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/routes.py:19:37"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example request**
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /articles/offset?page=2&items_per_page=10&search=fastapi&status=published&order_by=title&order=asc
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example response**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"pagination_type": "offset",
|
||||||
|
"data": [
|
||||||
|
{ "id": "3f47ac69-...", "title": "FastAPI tips", "status": "published", ... }
|
||||||
|
],
|
||||||
|
"pagination": {
|
||||||
|
"total_count": 42,
|
||||||
|
"pages": 5,
|
||||||
|
"page": 2,
|
||||||
|
"items_per_page": 10,
|
||||||
|
"has_more": true
|
||||||
|
},
|
||||||
|
"filter_attributes": {
|
||||||
|
"status": ["archived", "draft", "published"],
|
||||||
|
"name": ["backend", "frontend", "python"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
`filter_attributes` always reflects the values visible **after** applying the active filters. Use it to populate filter dropdowns on the client.
|
||||||
|
|
||||||
|
To skip the `COUNT(*)` query for better performance on large tables, pass `include_total=False`. `pagination.total_count` will be `null` in the response, while `has_more` remains accurate.
|
||||||
|
|
||||||
|
### Cursor pagination
|
||||||
|
|
||||||
|
Best for feeds, infinite scroll, or any high-throughput API where offset performance degrades.
|
||||||
|
|
||||||
|
```python title="routes.py:40:58"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/routes.py:40:58"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example request**
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /articles/cursor?items_per_page=10&status=published&order_by=created_at&order=desc
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example response**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"pagination_type": "cursor",
|
||||||
|
"data": [
|
||||||
|
{ "id": "3f47ac69-...", "title": "FastAPI tips", "status": "published", ... }
|
||||||
|
],
|
||||||
|
"pagination": {
|
||||||
|
"next_cursor": "eyJ2YWx1ZSI6ICIzZjQ3YWM2OS0uLi4ifQ==",
|
||||||
|
"prev_cursor": null,
|
||||||
|
"items_per_page": 10,
|
||||||
|
"has_more": true
|
||||||
|
},
|
||||||
|
"filter_attributes": {
|
||||||
|
"status": ["published"],
|
||||||
|
"name": ["backend", "python"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Pass `next_cursor` as the `cursor` query parameter on the next request to advance to the next page.
|
||||||
|
|
||||||
|
### Unified endpoint (both strategies)
|
||||||
|
|
||||||
|
!!! info "Added in `v2.3.0`"
|
||||||
|
|
||||||
|
[`paginate()`](../module/crud.md#unified-paginate--both-strategies-on-one-endpoint) lets a single endpoint support both strategies via a `pagination_type` query parameter. The `pagination_type` field in the response acts as a discriminator for frontend tooling.
|
||||||
|
|
||||||
|
```python title="routes.py:61:79"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/routes.py:61:79"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Offset request** (default)
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /articles/?pagination_type=offset&page=1&items_per_page=10
|
||||||
|
```
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"pagination_type": "offset",
|
||||||
|
"data": ["..."],
|
||||||
|
"pagination": { "total_count": 42, "pages": 5, "page": 1, "items_per_page": 10, "has_more": true }
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Cursor request**
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /articles/?pagination_type=cursor&items_per_page=10
|
||||||
|
GET /articles/?pagination_type=cursor&items_per_page=10&cursor=eyJ2YWx1ZSI6...
|
||||||
|
```
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"pagination_type": "cursor",
|
||||||
|
"data": ["..."],
|
||||||
|
"pagination": { "next_cursor": "eyJ2YWx1ZSI6...", "prev_cursor": null, "items_per_page": 10, "has_more": true }
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Search behaviour
|
||||||
|
|
||||||
|
Both endpoints inherit the same `searchable_fields` declared on `ArticleCrud`:
|
||||||
|
|
||||||
|
Search is **case-insensitive** and uses a `LIKE %query%` pattern. Pass a [`SearchConfig`](../reference/crud.md#fastapi_toolsets.crud.search.SearchConfig) instead of a plain string to control case sensitivity or switch to `match_mode="all"` (AND across all fields instead of OR).
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.crud import SearchConfig
|
||||||
|
|
||||||
|
# Both title AND body must contain "fastapi"
|
||||||
|
result = await ArticleCrud.offset_paginate(
|
||||||
|
session,
|
||||||
|
search=SearchConfig(query="fastapi", case_sensitive=True, match_mode="all"),
|
||||||
|
search_fields=[Article.title, Article.body],
|
||||||
|
)
|
||||||
|
```
|
||||||
69
docs/index.md
Normal file
69
docs/index.md
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
# FastAPI Toolsets
|
||||||
|
|
||||||
|
A modular collection of production-ready utilities for FastAPI. Install only what you need — from async CRUD and database helpers to CLI tooling, Prometheus metrics, and pytest fixtures. Each module is independently installable via optional extras, keeping your dependency footprint minimal.
|
||||||
|
|
||||||
|
[](https://github.com/d3vyce/fastapi-toolsets/actions/workflows/ci.yml)
|
||||||
|
[](https://codecov.io/gh/d3vyce/fastapi-toolsets)
|
||||||
|
[](https://github.com/astral-sh/ty)
|
||||||
|
[](https://github.com/astral-sh/uv)
|
||||||
|
[](https://github.com/astral-sh/ruff)
|
||||||
|
[](https://www.python.org/downloads/)
|
||||||
|
[](https://opensource.org/licenses/MIT)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Documentation**: [https://fastapi-toolsets.d3vyce.fr](https://fastapi-toolsets.d3vyce.fr)
|
||||||
|
|
||||||
|
**Source Code**: [https://github.com/d3vyce/fastapi-toolsets](https://github.com/d3vyce/fastapi-toolsets)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
The base package includes the core modules (CRUD, database, schemas, exceptions, fixtures, dependencies, model mixins, logging):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add fastapi-toolsets
|
||||||
|
```
|
||||||
|
|
||||||
|
Install only the extras you need:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add "fastapi-toolsets[cli]"
|
||||||
|
uv add "fastapi-toolsets[metrics]"
|
||||||
|
uv add "fastapi-toolsets[pytest]"
|
||||||
|
```
|
||||||
|
|
||||||
|
Or install everything:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add "fastapi-toolsets[all]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
### Core
|
||||||
|
|
||||||
|
- **CRUD**: Generic async CRUD operations with `CrudFactory`, built-in full-text/faceted search and Offset/Cursor pagination.
|
||||||
|
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
|
||||||
|
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
|
||||||
|
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
|
||||||
|
- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`).
|
||||||
|
- **Lifecycle Events**: Post-commit event system (`EventSession`, `listens_for`) that dispatches async/sync callbacks for insert, update, and delete operations.
|
||||||
|
- **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`.
|
||||||
|
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
||||||
|
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`
|
||||||
|
|
||||||
|
### Optional
|
||||||
|
|
||||||
|
- **CLI**: Django-like command-line interface with fixture management and custom commands support
|
||||||
|
- **Metrics**: Prometheus metrics endpoint with provider/collector registry
|
||||||
|
- **Pytest Helpers**: Async test client, database session management, `pytest-xdist` support, and table cleanup utilities
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT License - see [LICENSE](LICENSE) for details.
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
Contributions are welcome! Please feel free to submit issues and pull requests.
|
||||||
137
docs/migration/v2.md
Normal file
137
docs/migration/v2.md
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
# Migrating to v2.0
|
||||||
|
|
||||||
|
This page covers every breaking change introduced in **v2.0** and the steps required to update your code.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## CRUD
|
||||||
|
|
||||||
|
### `schema` is now required in `offset_paginate()` and `cursor_paginate()`
|
||||||
|
|
||||||
|
Calls that omit `schema` will now raise a `TypeError` at runtime.
|
||||||
|
|
||||||
|
Previously `schema` was optional; omitting it returned raw SQLAlchemy model instances inside the response. It is now a required keyword argument and the response always contains serialized schema instances.
|
||||||
|
|
||||||
|
=== "Before (`v1`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
# schema omitted — returned raw model instances
|
||||||
|
result = await UserCrud.offset_paginate(session=session, page=1)
|
||||||
|
result = await UserCrud.cursor_paginate(session=session, cursor=token)
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Now (`v2`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
result = await UserCrud.offset_paginate(session=session, page=1, schema=UserRead)
|
||||||
|
result = await UserCrud.cursor_paginate(session=session, cursor=token, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
### `as_response` removed from `create()`, `get()`, and `update()`
|
||||||
|
|
||||||
|
Passing `as_response` to these methods will raise a `TypeError` at runtime.
|
||||||
|
|
||||||
|
The `as_response=True` shorthand is replaced by passing a `schema` directly. The return value is a `Response[schema]` when `schema` is provided, or the raw model instance when it is not.
|
||||||
|
|
||||||
|
=== "Before (`v1`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
user = await UserCrud.create(session=session, obj=data, as_response=True)
|
||||||
|
user = await UserCrud.get(session=session, filters=filters, as_response=True)
|
||||||
|
user = await UserCrud.update(session=session, obj=data, filters, as_response=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Now (`v2`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
user = await UserCrud.create(session=session, obj=data, schema=UserRead)
|
||||||
|
user = await UserCrud.get(session=session, filters=filters, schema=UserRead)
|
||||||
|
user = await UserCrud.update(session=session, obj=data, filters, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
### `delete()`: `as_response` renamed and return type changed
|
||||||
|
|
||||||
|
`as_response` is gone, and the plain (non-response) call no longer returns `True`.
|
||||||
|
|
||||||
|
Two changes were made to `delete()`:
|
||||||
|
|
||||||
|
1. The `as_response` parameter is renamed to `return_response`.
|
||||||
|
2. When called without `return_response=True`, the method now returns `None` on success instead of `True`.
|
||||||
|
|
||||||
|
=== "Before (`v1`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
ok = await UserCrud.delete(session=session, filters=filters)
|
||||||
|
if ok: # True on success
|
||||||
|
...
|
||||||
|
|
||||||
|
response = await UserCrud.delete(session=session, filters=filters, as_response=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Now (`v2`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
await UserCrud.delete(session=session, filters=filters) # returns None
|
||||||
|
|
||||||
|
response = await UserCrud.delete(session=session, filters=filters, return_response=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
### `paginate()` alias removed
|
||||||
|
|
||||||
|
Any call to `crud.paginate(...)` will raise `AttributeError` at runtime.
|
||||||
|
|
||||||
|
The `paginate` shorthand was an alias for `offset_paginate`. It has been removed; call `offset_paginate` directly.
|
||||||
|
|
||||||
|
=== "Before (`v1`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
result = await UserCrud.paginate(session=session, page=2, items_per_page=20, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Now (`v2`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
result = await UserCrud.offset_paginate(session=session, page=2, items_per_page=20, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Exceptions
|
||||||
|
|
||||||
|
### Missing `api_error` raises `TypeError` at class definition time
|
||||||
|
|
||||||
|
Unfinished or stub exception subclasses that previously compiled fine will now fail on import.
|
||||||
|
|
||||||
|
In `v1`, a subclass without `api_error` would only fail when the exception was raised. In `v2`, `__init_subclass__` validates this at class definition time.
|
||||||
|
|
||||||
|
=== "Before (`v1`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
class MyError(ApiException):
|
||||||
|
pass # fine until raised
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Now (`v2`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
class MyError(ApiException):
|
||||||
|
pass # TypeError: MyError must define an 'api_error' class attribute.
|
||||||
|
```
|
||||||
|
|
||||||
|
For shared base classes that are not meant to be raised directly, use `abstract=True`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class BillingError(ApiException, abstract=True):
|
||||||
|
"""Base for all billing-related errors — not raised directly."""
|
||||||
|
|
||||||
|
class PaymentRequiredError(BillingError):
|
||||||
|
api_error = ApiError(code=402, msg="Payment Required", desc="...", err_code="BILLING-402")
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Schemas
|
||||||
|
|
||||||
|
### `Pagination` alias removed
|
||||||
|
|
||||||
|
`Pagination` was already deprecated in `v1` and is fully removed in `v2`, you now need to use [`OffsetPagination`](../reference/schemas.md#fastapi_toolsets.schemas.OffsetPagination) or [`CursorPagination`](../reference/schemas.md#fastapi_toolsets.schemas.CursorPagination).
|
||||||
180
docs/migration/v3.md
Normal file
180
docs/migration/v3.md
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
# Migrating to v3.0
|
||||||
|
|
||||||
|
This page covers every breaking change introduced in **v3.0** and the steps required to update your code.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## CRUD
|
||||||
|
|
||||||
|
### Facet keys now always use the full relationship chain
|
||||||
|
|
||||||
|
In `v2`, relationship facet fields used only the terminal column key (e.g. `"name"` for `Role.name`) and only prepended the relationship name when two facet fields shared the same column key. In `v3`, facet keys **always** include the full relationship chain joined by `__`, regardless of collisions.
|
||||||
|
|
||||||
|
=== "Before (`v2`)"
|
||||||
|
|
||||||
|
```
|
||||||
|
User.status -> status
|
||||||
|
(User.role, Role.name) -> name
|
||||||
|
(User.role, Role.permission, Permission.name) -> name
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Now (`v3`)"
|
||||||
|
|
||||||
|
```
|
||||||
|
User.status -> status
|
||||||
|
(User.role, Role.name) -> role__name
|
||||||
|
(User.role, Role.permission, Permission.name) -> role__permission__name
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### `*_params` dependencies consolidated into per-paginate methods
|
||||||
|
|
||||||
|
The six individual dependency methods (`offset_params`, `cursor_params`, `paginate_params`, `filter_params`, `search_params`, `order_params`) have been **removed** and replaced by three consolidated methods that bundle pagination, search, filter, and order into a single `Depends()` call.
|
||||||
|
|
||||||
|
| Removed | Replacement |
|
||||||
|
|---|---|
|
||||||
|
| `offset_params()` + `filter_params()` + `search_params()` + `order_params()` | `offset_paginate_params()` |
|
||||||
|
| `cursor_params()` + `filter_params()` + `search_params()` + `order_params()` | `cursor_paginate_params()` |
|
||||||
|
| `paginate_params()` + `filter_params()` + `search_params()` + `order_params()` | `paginate_params()` |
|
||||||
|
|
||||||
|
Each new method accepts `search`, `filter`, and `order` boolean toggles (all `True` by default) to disable features you don't need.
|
||||||
|
|
||||||
|
=== "Before (`v2`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.crud import OrderByClause
|
||||||
|
|
||||||
|
@router.get("/offset")
|
||||||
|
async def list_articles_offset(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[dict, Depends(ArticleCrud.offset_params(default_page_size=20))],
|
||||||
|
filter_by: Annotated[dict, Depends(ArticleCrud.filter_params())],
|
||||||
|
order_by: Annotated[OrderByClause | None, Depends(ArticleCrud.order_params(default_field=Article.created_at))],
|
||||||
|
search: str | None = None,
|
||||||
|
) -> OffsetPaginatedResponse[ArticleRead]:
|
||||||
|
return await ArticleCrud.offset_paginate(
|
||||||
|
session=session,
|
||||||
|
**params,
|
||||||
|
search=search,
|
||||||
|
filter_by=filter_by or None,
|
||||||
|
order_by=order_by,
|
||||||
|
schema=ArticleRead,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Now (`v3`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
@router.get("/offset")
|
||||||
|
async def list_articles_offset(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[
|
||||||
|
dict,
|
||||||
|
Depends(
|
||||||
|
ArticleCrud.offset_paginate_params(
|
||||||
|
default_page_size=20,
|
||||||
|
default_order_field=Article.created_at,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
) -> OffsetPaginatedResponse[ArticleRead]:
|
||||||
|
return await ArticleCrud.offset_paginate(session=session, **params, schema=ArticleRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
The same pattern applies to `cursor_paginate_params()` and `paginate_params()`. To disable a feature, pass the toggle:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# No search or ordering, only pagination + filtering
|
||||||
|
ArticleCrud.offset_paginate_params(search=False, order=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Models
|
||||||
|
|
||||||
|
The lifecycle event system has been rewritten. Callbacks are now registered with a module-level [`listens_for`](../reference/models.md#fastapi_toolsets.models.listens_for) decorator and dispatched by [`EventSession`](../reference/models.md#fastapi_toolsets.models.EventSession), replacing the mixin-based approach from `v2`.
|
||||||
|
|
||||||
|
### `WatchedFieldsMixin` and `@watch` removed
|
||||||
|
|
||||||
|
Importing `WatchedFieldsMixin` or `watch` will raise `ImportError`.
|
||||||
|
|
||||||
|
Model method callbacks (`on_create`, `on_delete`, `on_update`) and the `@watch` decorator are replaced by:
|
||||||
|
|
||||||
|
1. **`__watched_fields__`** — a plain class attribute to restrict which field changes trigger `UPDATE` events (replaces `@watch`).
|
||||||
|
2. **`@listens_for`** — a module-level decorator to register callbacks for one or more [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) types (replaces `on_create` / `on_delete` / `on_update` methods).
|
||||||
|
|
||||||
|
=== "Before (`v2`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import WatchedFieldsMixin, watch
|
||||||
|
|
||||||
|
@watch("status")
|
||||||
|
class Order(Base, UUIDMixin, WatchedFieldsMixin):
|
||||||
|
__tablename__ = "orders"
|
||||||
|
|
||||||
|
status: Mapped[str]
|
||||||
|
|
||||||
|
async def on_create(self):
|
||||||
|
await notify_new_order(self.id)
|
||||||
|
|
||||||
|
async def on_update(self, changes):
|
||||||
|
if "status" in changes:
|
||||||
|
await notify_status_change(self.id, changes["status"])
|
||||||
|
|
||||||
|
async def on_delete(self):
|
||||||
|
await notify_order_cancelled(self.id)
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Now (`v3`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import ModelEvent, UUIDMixin, listens_for
|
||||||
|
|
||||||
|
class Order(Base, UUIDMixin):
|
||||||
|
__tablename__ = "orders"
|
||||||
|
__watched_fields__ = ("status",)
|
||||||
|
|
||||||
|
status: Mapped[str]
|
||||||
|
|
||||||
|
@listens_for(Order, [ModelEvent.CREATE])
|
||||||
|
async def on_order_created(order: Order, event_type: ModelEvent, changes: None):
|
||||||
|
await notify_new_order(order.id)
|
||||||
|
|
||||||
|
@listens_for(Order, [ModelEvent.UPDATE])
|
||||||
|
async def on_order_updated(order: Order, event_type: ModelEvent, changes: dict):
|
||||||
|
if "status" in changes:
|
||||||
|
await notify_status_change(order.id, changes["status"])
|
||||||
|
|
||||||
|
@listens_for(Order, [ModelEvent.DELETE])
|
||||||
|
async def on_order_deleted(order: Order, event_type: ModelEvent, changes: None):
|
||||||
|
await notify_order_cancelled(order.id)
|
||||||
|
```
|
||||||
|
|
||||||
|
### `EventSession` now required
|
||||||
|
|
||||||
|
Without `EventSession`, lifecycle callbacks will silently stop firing.
|
||||||
|
|
||||||
|
Callbacks are now dispatched inside `EventSession.commit()` rather than via background tasks. Pass it as the session class when creating your session factory:
|
||||||
|
|
||||||
|
=== "Before (`v2`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
engine = create_async_engine("postgresql+asyncpg://...")
|
||||||
|
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Now (`v3`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
|
from fastapi_toolsets.models import EventSession
|
||||||
|
|
||||||
|
engine = create_async_engine("postgresql+asyncpg://...")
|
||||||
|
SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=EventSession)
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
If you use `create_db_session` from `fastapi_toolsets.pytest`, the session already uses `EventSession` — no changes needed in tests.
|
||||||
93
docs/module/cli.md
Normal file
93
docs/module/cli.md
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
# CLI
|
||||||
|
|
||||||
|
Typer-based command-line interface for managing your FastAPI application, with built-in fixture commands integration.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
=== "uv"
|
||||||
|
``` bash
|
||||||
|
uv add "fastapi-toolsets[cli]"
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "pip"
|
||||||
|
``` bash
|
||||||
|
pip install "fastapi-toolsets[cli]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `cli` module provides a `manager` entry point built with [Typer](https://typer.tiangolo.com/). It allow custom commands to be added in addition of the fixture commands when a [`FixtureRegistry`](../reference/fixtures.md#fastapi_toolsets.fixtures.registry.FixtureRegistry) and a database context are configured.
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Configure the CLI in your `pyproject.toml`:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[tool.fastapi-toolsets]
|
||||||
|
cli = "myapp.cli:cli" # Custom Typer app
|
||||||
|
fixtures = "myapp.fixtures:registry" # FixtureRegistry instance
|
||||||
|
db_context = "myapp.db:db_context" # Async context manager for sessions
|
||||||
|
```
|
||||||
|
|
||||||
|
All fields are optional. Without configuration, the `manager` command still works but no command are available.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Manager commands
|
||||||
|
manager --help
|
||||||
|
|
||||||
|
Usage: manager [OPTIONS] COMMAND [ARGS]...
|
||||||
|
|
||||||
|
FastAPI utilities CLI.
|
||||||
|
|
||||||
|
╭─ Options ────────────────────────────────────────────────────────────────────────╮
|
||||||
|
│ --install-completion Install completion for the current shell. │
|
||||||
|
│ --show-completion Show completion for the current shell, to copy it │
|
||||||
|
│ or customize the installation. │
|
||||||
|
│ --help Show this message and exit. │
|
||||||
|
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||||
|
╭─ Commands ───────────────────────────────────────────────────────────────────────╮
|
||||||
|
│ check-db │
|
||||||
|
│ fixtures Manage database fixtures. │
|
||||||
|
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||||
|
|
||||||
|
# Fixtures commands
|
||||||
|
manager fixtures --help
|
||||||
|
|
||||||
|
Usage: manager fixtures [OPTIONS] COMMAND [ARGS]...
|
||||||
|
|
||||||
|
Manage database fixtures.
|
||||||
|
|
||||||
|
╭─ Options ────────────────────────────────────────────────────────────────────────╮
|
||||||
|
│ --help Show this message and exit. │
|
||||||
|
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||||
|
╭─ Commands ───────────────────────────────────────────────────────────────────────╮
|
||||||
|
│ list List all registered fixtures. │
|
||||||
|
│ load Load fixtures into the database. │
|
||||||
|
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||||
|
```
|
||||||
|
|
||||||
|
## Custom CLI
|
||||||
|
|
||||||
|
You can extend the CLI by providing your own Typer app. The `manager` entry point will merge your app's commands with the built-in ones:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# myapp/cli.py
|
||||||
|
import typer
|
||||||
|
|
||||||
|
cli = typer.Typer()
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
|
def hello():
|
||||||
|
print("Hello from my app!")
|
||||||
|
```
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[tool.fastapi-toolsets]
|
||||||
|
cli = "myapp.cli:cli"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/cli.md)
|
||||||
577
docs/module/crud.md
Normal file
577
docs/module/crud.md
Normal file
@@ -0,0 +1,577 @@
|
|||||||
|
# CRUD
|
||||||
|
|
||||||
|
Generic async CRUD operations for SQLAlchemy models with search, pagination, and many-to-many support.
|
||||||
|
|
||||||
|
!!! info
|
||||||
|
This module has been coded and tested to be compatible with PostgreSQL only.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `crud` module provides [`AsyncCrud`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud), a base class with a full suite of async database operations, and [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory), a convenience function to instantiate it for a given model.
|
||||||
|
|
||||||
|
## Creating a CRUD class
|
||||||
|
|
||||||
|
### Factory style
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.crud import CrudFactory
|
||||||
|
from myapp.models import User
|
||||||
|
|
||||||
|
UserCrud = CrudFactory(model=User)
|
||||||
|
```
|
||||||
|
|
||||||
|
[`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory) dynamically creates a class named `AsyncUserCrud` with `User` as its model. This is the most concise option for straightforward CRUD with no custom logic.
|
||||||
|
|
||||||
|
### Subclass style
|
||||||
|
|
||||||
|
!!! info "Added in `v2.3.0`"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.crud.factory import AsyncCrud
|
||||||
|
from myapp.models import User
|
||||||
|
|
||||||
|
class UserCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
searchable_fields = [User.username, User.email]
|
||||||
|
default_load_options = [selectinload(User.role)]
|
||||||
|
```
|
||||||
|
|
||||||
|
Subclassing [`AsyncCrud`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud) directly is the preferred style when you need to add custom methods or when the configuration is complex enough to benefit from a named class body.
|
||||||
|
|
||||||
|
### Adding custom methods
|
||||||
|
|
||||||
|
```python
|
||||||
|
class UserCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_active(cls, session: AsyncSession) -> list[User]:
|
||||||
|
return await cls.get_multi(session, filters=[User.is_active == True])
|
||||||
|
```
|
||||||
|
|
||||||
|
### Sharing a custom base across multiple models
|
||||||
|
|
||||||
|
Define a generic base class with the shared methods, then subclass it for each model:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
from fastapi_toolsets.crud.factory import AsyncCrud
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=DeclarativeBase)
|
||||||
|
|
||||||
|
class AuditedCrud(AsyncCrud[T], Generic[T]):
|
||||||
|
"""Base CRUD with custom function"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_active(cls, session: AsyncSession):
|
||||||
|
return await cls.get_multi(session, filters=[cls.model.is_active == True])
|
||||||
|
|
||||||
|
|
||||||
|
class UserCrud(AuditedCrud[User]):
|
||||||
|
model = User
|
||||||
|
searchable_fields = [User.username, User.email]
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also use the factory shorthand with the same base by passing `base_class`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
UserCrud = CrudFactory(User, base_class=AuditedCrud)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Basic operations
|
||||||
|
|
||||||
|
!!! info "`get_or_none` added in `v2.2`"
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Create
|
||||||
|
user = await UserCrud.create(session=session, obj=UserCreateSchema(username="alice"))
|
||||||
|
|
||||||
|
# Get one (raises NotFoundError if not found)
|
||||||
|
user = await UserCrud.get(session=session, filters=[User.id == user_id])
|
||||||
|
|
||||||
|
# Get one or None (never raises)
|
||||||
|
user = await UserCrud.get_or_none(session=session, filters=[User.id == user_id])
|
||||||
|
|
||||||
|
# Get first or None
|
||||||
|
user = await UserCrud.first(session=session, filters=[User.email == email])
|
||||||
|
|
||||||
|
# Get multiple
|
||||||
|
users = await UserCrud.get_multi(session=session, filters=[User.is_active == True])
|
||||||
|
|
||||||
|
# Update
|
||||||
|
user = await UserCrud.update(session=session, obj=UserUpdateSchema(username="bob"), filters=[User.id == user_id])
|
||||||
|
|
||||||
|
# Delete
|
||||||
|
await UserCrud.delete(session=session, filters=[User.id == user_id])
|
||||||
|
|
||||||
|
# Count / exists
|
||||||
|
count = await UserCrud.count(session=session, filters=[User.is_active == True])
|
||||||
|
exists = await UserCrud.exists(session=session, filters=[User.email == email])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Fetching a single record
|
||||||
|
|
||||||
|
Three methods fetch a single record — choose based on how you want to handle the "not found" case and whether you need strict uniqueness:
|
||||||
|
|
||||||
|
| Method | Not found | Multiple results |
|
||||||
|
|---|---|---|
|
||||||
|
| `get` | raises `NotFoundError` | raises `MultipleResultsFound` |
|
||||||
|
| `get_or_none` | returns `None` | raises `MultipleResultsFound` |
|
||||||
|
| `first` | returns `None` | returns the first match silently |
|
||||||
|
|
||||||
|
Use `get` when the record must exist (e.g. a detail endpoint that should return 404):
|
||||||
|
|
||||||
|
```python
|
||||||
|
user = await UserCrud.get(session=session, filters=[User.id == user_id])
|
||||||
|
```
|
||||||
|
|
||||||
|
Use `get_or_none` when the record may not exist but you still want strict uniqueness enforcement:
|
||||||
|
|
||||||
|
```python
|
||||||
|
user = await UserCrud.get_or_none(session=session, filters=[User.email == email])
|
||||||
|
if user is None:
|
||||||
|
... # handle missing case without catching an exception
|
||||||
|
```
|
||||||
|
|
||||||
|
Use `first` when you only care about any one match and don't need uniqueness:
|
||||||
|
|
||||||
|
```python
|
||||||
|
user = await UserCrud.first(session=session, filters=[User.is_active == True])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pagination
|
||||||
|
|
||||||
|
!!! info "Added in `v1.1` (only offset_pagination via `paginate` if `<v1.1`)"
|
||||||
|
|
||||||
|
Three pagination methods are available. All return a typed response whose `pagination_type` field tells clients which strategy was used.
|
||||||
|
|
||||||
|
| | `offset_paginate` | `cursor_paginate` | `paginate` |
|
||||||
|
|---|---|---|---|
|
||||||
|
| Return type | `OffsetPaginatedResponse` | `CursorPaginatedResponse` | either, based on `pagination_type` param |
|
||||||
|
| Total count | Yes | No | / |
|
||||||
|
| Jump to arbitrary page | Yes | No | / |
|
||||||
|
| Performance on deep pages | Degrades | Constant | / |
|
||||||
|
| Stable under concurrent inserts | No | Yes | / |
|
||||||
|
| Use case | Admin panels, numbered pagination | Feeds, APIs, infinite scroll | single endpoint, both strategies |
|
||||||
|
|
||||||
|
### Offset pagination
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import Annotated
|
||||||
|
from fastapi import Depends
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def get_users(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[dict, Depends(UserCrud.offset_paginate_params())],
|
||||||
|
) -> OffsetPaginatedResponse[UserRead]:
|
||||||
|
return await UserCrud.offset_paginate(session=session, **params, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
The [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) method returns an [`OffsetPaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.OffsetPaginatedResponse):
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"pagination_type": "offset",
|
||||||
|
"data": ["..."],
|
||||||
|
"pagination": {
|
||||||
|
"total_count": 100,
|
||||||
|
"pages": 5,
|
||||||
|
"page": 1,
|
||||||
|
"items_per_page": 20,
|
||||||
|
"has_more": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Skipping the COUNT query
|
||||||
|
|
||||||
|
!!! info "Added in `v2.4.1`"
|
||||||
|
|
||||||
|
By default `offset_paginate` runs two queries: one for the page items and one `COUNT(*)` for `total_count`. On large tables the `COUNT` can be expensive. Pass `include_total=False` to `offset_paginate_params()` to skip it:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@router.get("")
|
||||||
|
async def get_users(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[dict, Depends(UserCrud.offset_paginate_params(include_total=False))],
|
||||||
|
) -> OffsetPaginatedResponse[UserRead]:
|
||||||
|
return await UserCrud.offset_paginate(session=session, **params, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Cursor pagination
|
||||||
|
|
||||||
|
```python
|
||||||
|
@router.get("")
|
||||||
|
async def list_users(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[dict, Depends(UserCrud.cursor_paginate_params())],
|
||||||
|
) -> CursorPaginatedResponse[UserRead]:
|
||||||
|
return await UserCrud.cursor_paginate(session=session, **params, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
The [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) method returns a [`CursorPaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.CursorPaginatedResponse):
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"pagination_type": "cursor",
|
||||||
|
"data": ["..."],
|
||||||
|
"pagination": {
|
||||||
|
"next_cursor": "eyJ2YWx1ZSI6ICIzZjQ3YWM2OS0uLi4ifQ==",
|
||||||
|
"prev_cursor": null,
|
||||||
|
"items_per_page": 20,
|
||||||
|
"has_more": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Pass `next_cursor` as the `cursor` query parameter on the next request to advance to the next page. `prev_cursor` is set on pages 2+ and points back to the first item of the current page. Both are `null` when there is no adjacent page.
|
||||||
|
|
||||||
|
#### Choosing a cursor column
|
||||||
|
|
||||||
|
The cursor column is set once on [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory) via the `cursor_column` parameter. It must be monotonically ordered for stable results:
|
||||||
|
|
||||||
|
- Auto-increment integer PKs
|
||||||
|
- UUID v7 PKs
|
||||||
|
- Timestamps
|
||||||
|
|
||||||
|
!!! warning
|
||||||
|
Random UUID v4 PKs are **not** suitable as cursor columns because their ordering is non-deterministic.
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
`cursor_column` is required. Calling [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) on a CRUD class that has no `cursor_column` configured raises a `ValueError`.
|
||||||
|
|
||||||
|
The cursor value is URL-safe base64-encoded (no padding) when returned to the client and decoded back to the correct Python type on the next request. The following SQLAlchemy column types are supported:
|
||||||
|
|
||||||
|
| SQLAlchemy type | Python type |
|
||||||
|
|---|---|
|
||||||
|
| `Integer`, `BigInteger`, `SmallInteger` | `int` |
|
||||||
|
| `Uuid` | `uuid.UUID` |
|
||||||
|
| `DateTime` | `datetime.datetime` |
|
||||||
|
| `Date` | `datetime.date` |
|
||||||
|
| `Float`, `Numeric` | `decimal.Decimal` |
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Paginate by the primary key
|
||||||
|
PostCrud = CrudFactory(model=Post, cursor_column=Post.id)
|
||||||
|
|
||||||
|
# Paginate by a timestamp column instead
|
||||||
|
PostCrud = CrudFactory(model=Post, cursor_column=Post.created_at)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Unified endpoint (both strategies)
|
||||||
|
|
||||||
|
!!! info "Added in `v2.3.0`"
|
||||||
|
|
||||||
|
[`paginate()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.paginate) dispatches to `offset_paginate` or `cursor_paginate` based on a `pagination_type` query parameter, letting you expose **one endpoint** that supports both strategies. The `pagination_type` field in the response tells clients which strategy was used, enabling frontend discriminated-union typing.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.schemas import PaginatedResponse
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def list_users(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[dict, Depends(UserCrud.paginate_params())],
|
||||||
|
) -> PaginatedResponse[UserRead]:
|
||||||
|
return await UserCrud.paginate(session, **params, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /users?pagination_type=offset&page=2&items_per_page=10
|
||||||
|
GET /users?pagination_type=cursor&cursor=eyJ2YWx1ZSI6...&items_per_page=10
|
||||||
|
```
|
||||||
|
|
||||||
|
## Search
|
||||||
|
|
||||||
|
Two search strategies are available, both compatible with [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) and [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate).
|
||||||
|
|
||||||
|
| | Full-text search | Faceted search |
|
||||||
|
|---|---|---|
|
||||||
|
| Input | Free-text string | Exact column values |
|
||||||
|
| Relationship support | Yes | Yes |
|
||||||
|
| Use case | Search bars | Filter dropdowns |
|
||||||
|
|
||||||
|
!!! info "You can use both search strategies in the same endpoint!"
|
||||||
|
|
||||||
|
### Full-text search
|
||||||
|
|
||||||
|
!!! info "Added in `v2.2.1`"
|
||||||
|
The model's primary key is always included in `searchable_fields` automatically, so searching by ID works out of the box without any configuration. When no `searchable_fields` are declared, only the primary key is searched.
|
||||||
|
|
||||||
|
Declare `searchable_fields` on the CRUD class. Relationship traversal is supported via tuples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
PostCrud = CrudFactory(
|
||||||
|
model=Post,
|
||||||
|
searchable_fields=[
|
||||||
|
Post.title,
|
||||||
|
Post.content,
|
||||||
|
(Post.author, User.username), # search across relationship
|
||||||
|
],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
You can override `searchable_fields` per call with `search_fields`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
result = await UserCrud.offset_paginate(
|
||||||
|
session=session,
|
||||||
|
search_fields=[User.country],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
This allows searching with both [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) and [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate):
|
||||||
|
|
||||||
|
```python
|
||||||
|
@router.get("")
|
||||||
|
async def get_users(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[dict, Depends(UserCrud.offset_paginate_params())],
|
||||||
|
) -> OffsetPaginatedResponse[UserRead]:
|
||||||
|
return await UserCrud.offset_paginate(session=session, **params, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
@router.get("")
|
||||||
|
async def get_users(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[dict, Depends(UserCrud.cursor_paginate_params())],
|
||||||
|
) -> CursorPaginatedResponse[UserRead]:
|
||||||
|
return await UserCrud.cursor_paginate(session=session, **params, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Faceted search
|
||||||
|
|
||||||
|
!!! info "Added in `v1.2`"
|
||||||
|
|
||||||
|
Declare `facet_fields` on the CRUD class to return distinct column values alongside paginated results. This is useful for populating filter dropdowns or building faceted search UIs.
|
||||||
|
|
||||||
|
Facet fields use the same syntax as `searchable_fields` — direct columns or relationship tuples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
UserCrud = CrudFactory(
|
||||||
|
model=User,
|
||||||
|
facet_fields=[
|
||||||
|
User.status,
|
||||||
|
User.country,
|
||||||
|
(User.role, Role.name), # value from a related model
|
||||||
|
],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
You can override `facet_fields` per call:
|
||||||
|
|
||||||
|
```python
|
||||||
|
result = await UserCrud.offset_paginate(
|
||||||
|
session=session,
|
||||||
|
facet_fields=[User.country],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
The distinct values are returned in the `filter_attributes` field of [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse):
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"data": ["..."],
|
||||||
|
"pagination": { "..." },
|
||||||
|
"filter_attributes": {
|
||||||
|
"status": ["active", "inactive"],
|
||||||
|
"country": ["DE", "FR", "US"],
|
||||||
|
"role__name": ["admin", "editor", "viewer"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Use `filter_by` to pass the client's chosen filter values directly — no need to build SQLAlchemy conditions by hand. Any unknown key raises [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError).
|
||||||
|
|
||||||
|
!!! info "The keys in `filter_by` are the same keys the client received in `filter_attributes`."
|
||||||
|
Keys use `__` as a separator for the full relationship chain. A direct column `User.status` produces `"status"`. A relationship tuple `(User.role, Role.name)` produces `"role__name"`. A deeper chain `(User.role, Role.permission, Permission.name)` produces `"role__permission__name"`.
|
||||||
|
|
||||||
|
`filter_by` and `filters` can be combined — both are applied with AND logic.
|
||||||
|
|
||||||
|
Facet filtering is built into the consolidated params dependencies. When `filter=True` (the default), facet fields are exposed as query parameters and collected into `filter_by` automatically:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
|
||||||
|
UserCrud = CrudFactory(
|
||||||
|
model=User,
|
||||||
|
facet_fields=[User.status, User.country, (User.role, Role.name)],
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("", response_model_exclude_none=True)
|
||||||
|
async def list_users(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[dict, Depends(UserCrud.offset_paginate_params())],
|
||||||
|
) -> OffsetPaginatedResponse[UserRead]:
|
||||||
|
return await UserCrud.offset_paginate(
|
||||||
|
session=session,
|
||||||
|
**params,
|
||||||
|
schema=UserRead,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Both single-value and multi-value query parameters work:
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /users?status=active → filter_by={"status": ["active"]}
|
||||||
|
GET /users?status=active&country=FR → filter_by={"status": ["active"], "country": ["FR"]}
|
||||||
|
GET /users?role__name=admin&role__name=editor → filter_by={"role__name": ["admin", "editor"]} (IN clause)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Sorting
|
||||||
|
|
||||||
|
!!! info "Added in `v1.3`"
|
||||||
|
|
||||||
|
Declare `order_fields` on the CRUD class to expose client-driven column ordering via `order_by` and `order` query parameters.
|
||||||
|
|
||||||
|
```python
|
||||||
|
UserCrud = CrudFactory(
|
||||||
|
model=User,
|
||||||
|
order_fields=[
|
||||||
|
User.name,
|
||||||
|
User.created_at,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Ordering is built into the consolidated params dependencies. When `order=True` (the default), `order_by` and `order` query parameters are exposed and resolved into an `OrderByClause` automatically:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def list_users(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[dict, Depends(UserCrud.offset_paginate_params(
|
||||||
|
default_order_field=User.created_at,
|
||||||
|
))],
|
||||||
|
) -> OffsetPaginatedResponse[UserRead]:
|
||||||
|
return await UserCrud.offset_paginate(session=session, **params, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
The dependency adds two query parameters to the endpoint:
|
||||||
|
|
||||||
|
| Parameter | Type |
|
||||||
|
| ---------- | --------------- |
|
||||||
|
| `order_by` | `str | null` |
|
||||||
|
| `order` | `asc` or `desc` |
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /users?order_by=name&order=asc → ORDER BY users.name ASC
|
||||||
|
GET /users?order_by=name&order=desc → ORDER BY users.name DESC
|
||||||
|
```
|
||||||
|
|
||||||
|
An unknown `order_by` value raises [`InvalidOrderFieldError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidOrderFieldError) (HTTP 422).
|
||||||
|
|
||||||
|
You can also pass `order_fields` directly to override the class-level defaults:
|
||||||
|
|
||||||
|
```python
|
||||||
|
params = UserCrud.offset_paginate_params(order_fields=[User.name])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Relationship loading
|
||||||
|
|
||||||
|
!!! info "Added in `v1.1`"
|
||||||
|
|
||||||
|
By default, SQLAlchemy relationships are not loaded unless explicitly requested. Instead of using `lazy="selectin"` on model definitions (which is implicit and applies globally), define a `default_load_options` on the CRUD class to control loading strategy explicitly.
|
||||||
|
|
||||||
|
!!! warning
|
||||||
|
Avoid using `lazy="selectin"` on model relationships. It fires silently on every query, cannot be disabled per-call, and can cause unexpected cascading loads through deep relationship chains. Use `default_load_options` instead.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
ArticleCrud = CrudFactory(
|
||||||
|
model=Article,
|
||||||
|
default_load_options=[
|
||||||
|
selectinload(Article.category),
|
||||||
|
selectinload(Article.tags),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
`default_load_options` applies automatically to all read operations (`get`, `first`, `get_multi`, `offset_paginate`, `cursor_paginate`). When `load_options` is passed at call-site, it **fully replaces** `default_load_options` for that query — giving you precise per-call control:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Only loads category, tags are not loaded
|
||||||
|
article = await ArticleCrud.get(
|
||||||
|
session=session,
|
||||||
|
filters=[Article.id == article_id],
|
||||||
|
load_options=[selectinload(Article.category)],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Loads nothing — useful for write-then-refresh flows or lightweight checks
|
||||||
|
articles = await ArticleCrud.get_multi(session=session, load_options=[])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Many-to-many relationships
|
||||||
|
|
||||||
|
Use `m2m_fields` to map schema fields containing lists of IDs to SQLAlchemy relationships. The CRUD class resolves and validates all IDs before persisting:
|
||||||
|
|
||||||
|
```python
|
||||||
|
PostCrud = CrudFactory(
|
||||||
|
model=Post,
|
||||||
|
m2m_fields={"tag_ids": Post.tags},
|
||||||
|
)
|
||||||
|
|
||||||
|
post = await PostCrud.create(session=session, obj=PostCreateSchema(title="Hello", tag_ids=[1, 2, 3]))
|
||||||
|
```
|
||||||
|
|
||||||
|
## Upsert
|
||||||
|
|
||||||
|
Atomic `INSERT ... ON CONFLICT DO UPDATE` using PostgreSQL:
|
||||||
|
|
||||||
|
```python
|
||||||
|
await UserCrud.upsert(
|
||||||
|
session=session,
|
||||||
|
obj=UserCreateSchema(email="alice@example.com", username="alice"),
|
||||||
|
index_elements=[User.email],
|
||||||
|
set_={"username"},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Response serialization
|
||||||
|
|
||||||
|
!!! info "Added in `v1.1`"
|
||||||
|
|
||||||
|
Pass a Pydantic schema class to `create`, `get`, `update`, or `offset_paginate` to serialize the result directly into that schema and wrap it in a [`Response[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.Response) or [`PaginatedResponse[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse):
|
||||||
|
|
||||||
|
```python
|
||||||
|
class UserRead(PydanticBase):
|
||||||
|
id: UUID
|
||||||
|
username: str
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/{uuid}",
|
||||||
|
responses=generate_error_responses(NotFoundError),
|
||||||
|
)
|
||||||
|
async def get_user(session: SessionDep, uuid: UUID) -> Response[UserRead]:
|
||||||
|
return await crud.UserCrud.get(
|
||||||
|
session=session,
|
||||||
|
filters=[User.id == uuid],
|
||||||
|
schema=UserRead,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def list_users(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[dict, Depends(crud.UserCrud.offset_paginate_params())],
|
||||||
|
) -> OffsetPaginatedResponse[UserRead]:
|
||||||
|
return await crud.UserCrud.offset_paginate(session=session, **params, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
The schema must have `from_attributes=True` (or inherit from [`PydanticBase`](../reference/schemas.md#fastapi_toolsets.schemas.PydanticBase)) so it can be built from SQLAlchemy model instances.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/crud.md)
|
||||||
123
docs/module/db.md
Normal file
123
docs/module/db.md
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
# DB
|
||||||
|
|
||||||
|
SQLAlchemy async session management with transactions, table locking, and row-change polling.
|
||||||
|
|
||||||
|
!!! info
|
||||||
|
This module has been coded and tested to be compatible with PostgreSQL only.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `db` module provides helpers to create FastAPI dependencies and context managers for `AsyncSession`, along with utilities for nested transactions, table lock and polling for row changes.
|
||||||
|
|
||||||
|
## Session dependency
|
||||||
|
|
||||||
|
Use [`create_db_dependency`](../reference/db.md#fastapi_toolsets.db.create_db_dependency) to create a FastAPI dependency that yields a session and auto-commits on success:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||||
|
from fastapi_toolsets.db import create_db_dependency
|
||||||
|
|
||||||
|
engine = create_async_engine(url="postgresql+asyncpg://...", future=True)
|
||||||
|
session_maker = async_sessionmaker(bind=engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
get_db = create_db_dependency(session_maker=session_maker)
|
||||||
|
|
||||||
|
@router.get("/users")
|
||||||
|
async def list_users(session: AsyncSession = Depends(get_db)):
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Session context manager
|
||||||
|
|
||||||
|
Use [`create_db_context`](../reference/db.md#fastapi_toolsets.db.create_db_context) for sessions outside request handlers (e.g. background tasks, CLI commands):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import create_db_context
|
||||||
|
|
||||||
|
db_context = create_db_context(session_maker=session_maker)
|
||||||
|
|
||||||
|
async def seed():
|
||||||
|
async with db_context() as session:
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Nested transactions
|
||||||
|
|
||||||
|
[`get_transaction`](../reference/db.md#fastapi_toolsets.db.get_transaction) handles savepoints automatically, allowing safe nesting:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import get_transaction
|
||||||
|
|
||||||
|
async def create_user_with_role(session=session):
|
||||||
|
async with get_transaction(session=session):
|
||||||
|
...
|
||||||
|
async with get_transaction(session=session): # uses savepoint
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Table locking
|
||||||
|
|
||||||
|
[`lock_tables`](../reference/db.md#fastapi_toolsets.db.lock_tables) acquires PostgreSQL table-level locks before executing critical sections:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import lock_tables
|
||||||
|
|
||||||
|
async with lock_tables(session=session, tables=[User], mode="EXCLUSIVE"):
|
||||||
|
# No other transaction can modify User until this block exits
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
Available lock modes are defined in [`LockMode`](../reference/db.md#fastapi_toolsets.db.LockMode): `ACCESS_SHARE`, `ROW_SHARE`, `ROW_EXCLUSIVE`, `SHARE_UPDATE_EXCLUSIVE`, `SHARE`, `SHARE_ROW_EXCLUSIVE`, `EXCLUSIVE`, `ACCESS_EXCLUSIVE`.
|
||||||
|
|
||||||
|
## Row-change polling
|
||||||
|
|
||||||
|
[`wait_for_row_change`](../reference/db.md#fastapi_toolsets.db.wait_for_row_change) polls a row until a specific column changes value, useful for waiting on async side effects:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import wait_for_row_change
|
||||||
|
|
||||||
|
# Wait up to 30s for order.status to change
|
||||||
|
await wait_for_row_change(
|
||||||
|
session=session,
|
||||||
|
model=Order,
|
||||||
|
pk_value=order_id,
|
||||||
|
columns=[Order.status],
|
||||||
|
interval=1.0,
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Creating a database
|
||||||
|
|
||||||
|
!!! info "Added in `v2.1`"
|
||||||
|
|
||||||
|
[`create_database`](../reference/db.md#fastapi_toolsets.db.create_database) creates a database at a given URL. It connects to *server_url* and issues a `CREATE DATABASE` statement:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import create_database
|
||||||
|
|
||||||
|
SERVER_URL = "postgresql+asyncpg://postgres:postgres@localhost/postgres"
|
||||||
|
|
||||||
|
await create_database(db_name="myapp_test", server_url=SERVER_URL)
|
||||||
|
```
|
||||||
|
|
||||||
|
For test isolation with automatic cleanup, use [`create_worker_database`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_worker_database) from the `pytest` module instead — it handles drop-before, create, and drop-after automatically.
|
||||||
|
|
||||||
|
## Cleaning up tables
|
||||||
|
|
||||||
|
!!! info "Added in `v2.1`"
|
||||||
|
|
||||||
|
[`cleanup_tables`](../reference/db.md#fastapi_toolsets.db.cleanup_tables) truncates all tables:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import cleanup_tables
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
async def clean(db_session):
|
||||||
|
yield
|
||||||
|
await cleanup_tables(session=db_session, base=Base)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/db.md)
|
||||||
61
docs/module/dependencies.md
Normal file
61
docs/module/dependencies.md
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
# Dependencies
|
||||||
|
|
||||||
|
FastAPI dependency factories for automatic model resolution from path and body parameters.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `dependencies` module provides two factory functions that create FastAPI dependencies to fetch a model instance from the database automatically — either from a path parameter or from a request body field — and inject it directly into your route handler.
|
||||||
|
|
||||||
|
## `PathDependency`
|
||||||
|
|
||||||
|
[`PathDependency`](../reference/dependencies.md#fastapi_toolsets.dependencies.PathDependency) resolves a model from a URL path parameter and injects it into the route handler. Raises [`NotFoundError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NotFoundError) automatically if the record does not exist.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.dependencies import PathDependency
|
||||||
|
|
||||||
|
# Plain callable
|
||||||
|
UserDep = PathDependency(model=User, field=User.id, session_dep=get_db)
|
||||||
|
|
||||||
|
# Annotated
|
||||||
|
SessionDep = Annotated[AsyncSession, Depends(get_db)]
|
||||||
|
UserDep = PathDependency(model=User, field=User.id, session_dep=SessionDep)
|
||||||
|
|
||||||
|
@router.get("/users/{user_id}")
|
||||||
|
async def get_user(user: User = UserDep):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
By default the parameter name is inferred from the field (`user_id` for `User.id`). You can override it:
|
||||||
|
|
||||||
|
```python
|
||||||
|
UserDep = PathDependency(model=User, field=User.id, session_dep=get_db, param_name="id")
|
||||||
|
|
||||||
|
@router.get("/users/{id}")
|
||||||
|
async def get_user(user: User = UserDep):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
## `BodyDependency`
|
||||||
|
|
||||||
|
[`BodyDependency`](../reference/dependencies.md#fastapi_toolsets.dependencies.BodyDependency) resolves a model from a field in the request body. Useful when a body contains a foreign key and you want the full object injected:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.dependencies import BodyDependency
|
||||||
|
|
||||||
|
# Plain callable
|
||||||
|
RoleDep = BodyDependency(model=Role, field=Role.id, session_dep=get_db, body_field="role_id")
|
||||||
|
|
||||||
|
# Annotated
|
||||||
|
SessionDep = Annotated[AsyncSession, Depends(get_db)]
|
||||||
|
RoleDep = BodyDependency(model=Role, field=Role.id, session_dep=SessionDep, body_field="role_id")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/users")
|
||||||
|
async def create_user(body: UserCreateSchema, role: Role = RoleDep):
|
||||||
|
user = User(username=body.username, role=role)
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/dependencies.md)
|
||||||
131
docs/module/exceptions.md
Normal file
131
docs/module/exceptions.md
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
# Exceptions
|
||||||
|
|
||||||
|
Structured API exceptions with consistent error responses and automatic OpenAPI documentation.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `exceptions` module provides a set of pre-built HTTP exceptions and a FastAPI exception handler that formats all errors — including validation errors — into a uniform [`ErrorResponse`](../reference/schemas.md#fastapi_toolsets.schemas.ErrorResponse).
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
Register the exception handlers on your FastAPI app at startup:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app=app)
|
||||||
|
```
|
||||||
|
|
||||||
|
This registers handlers for:
|
||||||
|
|
||||||
|
- [`ApiException`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ApiException) — all custom exceptions below
|
||||||
|
- `HTTPException` — Starlette/FastAPI HTTP errors
|
||||||
|
- `RequestValidationError` — Pydantic request validation (422)
|
||||||
|
- `ResponseValidationError` — Pydantic response validation (422)
|
||||||
|
- `Exception` — unhandled errors (500)
|
||||||
|
|
||||||
|
It also patches `app.openapi()` to replace the default Pydantic 422 schema with a structured example matching the `ErrorResponse` format.
|
||||||
|
|
||||||
|
## Built-in exceptions
|
||||||
|
|
||||||
|
| Exception | Status | Default message |
|
||||||
|
|-----------|--------|-----------------|
|
||||||
|
| [`UnauthorizedError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.UnauthorizedError) | 401 | Unauthorized |
|
||||||
|
| [`ForbiddenError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ForbiddenError) | 403 | Forbidden |
|
||||||
|
| [`NotFoundError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NotFoundError) | 404 | Not Found |
|
||||||
|
| [`ConflictError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ConflictError) | 409 | Conflict |
|
||||||
|
| [`NoSearchableFieldsError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError) | 400 | No Searchable Fields |
|
||||||
|
| [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError) | 400 | Invalid Facet Filter |
|
||||||
|
| [`InvalidOrderFieldError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidOrderFieldError) | 422 | Invalid Order Field |
|
||||||
|
|
||||||
|
### Per-instance overrides
|
||||||
|
|
||||||
|
All built-in exceptions accept optional keyword arguments to customise the response for a specific raise site without changing the class defaults:
|
||||||
|
|
||||||
|
| Argument | Effect |
|
||||||
|
|----------|--------|
|
||||||
|
| `detail` | Overrides both `str(exc)` (log output) and the `message` field in the response body |
|
||||||
|
| `desc` | Overrides the `description` field |
|
||||||
|
| `data` | Overrides the `data` field |
|
||||||
|
|
||||||
|
```python
|
||||||
|
raise NotFoundError(detail="User 42 not found", desc="No user with that ID exists in the database.")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Custom exceptions
|
||||||
|
|
||||||
|
Subclass [`ApiException`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ApiException) and define an `api_error` class variable:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.exceptions import ApiException
|
||||||
|
from fastapi_toolsets.schemas import ApiError
|
||||||
|
|
||||||
|
class PaymentRequiredError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=402,
|
||||||
|
msg="Payment Required",
|
||||||
|
desc="Your subscription has expired.",
|
||||||
|
err_code="BILLING-402",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! warning
|
||||||
|
Subclasses that do not define `api_error` raise a `TypeError` at **class creation time**, not at raise time.
|
||||||
|
|
||||||
|
### Custom `__init__`
|
||||||
|
|
||||||
|
Override `__init__` to compute `detail`, `desc`, or `data` dynamically, then delegate to `super().__init__()`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class OrderValidationError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=422,
|
||||||
|
msg="Order Validation Failed",
|
||||||
|
desc="One or more order fields are invalid.",
|
||||||
|
err_code="ORDER-422",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, *field_errors: str) -> None:
|
||||||
|
super().__init__(
|
||||||
|
f"{len(field_errors)} validation error(s)",
|
||||||
|
desc=", ".join(field_errors),
|
||||||
|
data={"errors": [{"message": e} for e in field_errors]},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Intermediate base classes
|
||||||
|
|
||||||
|
Use `abstract=True` when creating a shared base that is not meant to be raised directly:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class BillingError(ApiException, abstract=True):
|
||||||
|
"""Base for all billing-related errors."""
|
||||||
|
|
||||||
|
class PaymentRequiredError(BillingError):
|
||||||
|
api_error = ApiError(code=402, msg="Payment Required", desc="...", err_code="BILLING-402")
|
||||||
|
|
||||||
|
class SubscriptionExpiredError(BillingError):
|
||||||
|
api_error = ApiError(code=402, msg="Subscription Expired", desc="...", err_code="BILLING-402-EXP")
|
||||||
|
```
|
||||||
|
|
||||||
|
## OpenAPI response documentation
|
||||||
|
|
||||||
|
Use [`generate_error_responses`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.generate_error_responses) to add error schemas to your endpoint's OpenAPI spec:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.exceptions import generate_error_responses, NotFoundError, ForbiddenError
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/users/{id}",
|
||||||
|
responses=generate_error_responses(NotFoundError, ForbiddenError),
|
||||||
|
)
|
||||||
|
async def get_user(...): ...
|
||||||
|
```
|
||||||
|
|
||||||
|
Multiple exceptions sharing the same HTTP status code are grouped under one entry, each appearing as a named example keyed by its `err_code`. This keeps the OpenAPI UI readable when several error variants map to the same status.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/exceptions.md)
|
||||||
198
docs/module/fixtures.md
Normal file
198
docs/module/fixtures.md
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
# Fixtures
|
||||||
|
|
||||||
|
Dependency-aware database seeding with context-based loading strategies.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `fixtures` module lets you define named fixtures with dependencies between them, then load them into the database in the correct order. Fixtures can be scoped to contexts (e.g. base data, testing data) so that only the relevant ones are loaded for each environment.
|
||||||
|
|
||||||
|
## Defining fixtures
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import FixtureRegistry, Context
|
||||||
|
|
||||||
|
fixtures = FixtureRegistry()
|
||||||
|
|
||||||
|
@fixtures.register
|
||||||
|
def roles():
|
||||||
|
return [
|
||||||
|
Role(id=1, name="admin"),
|
||||||
|
Role(id=2, name="user"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
|
||||||
|
def test_users():
|
||||||
|
return [
|
||||||
|
User(id=1, username="alice", role_id=1),
|
||||||
|
User(id=2, username="bob", role_id=2),
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
Dependencies declared via `depends_on` are resolved topologically — `roles` will always be loaded before `test_users`.
|
||||||
|
|
||||||
|
## Loading fixtures
|
||||||
|
|
||||||
|
By context with [`load_fixtures_by_context`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.load_fixtures_by_context):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import load_fixtures_by_context
|
||||||
|
|
||||||
|
async with db_context() as session:
|
||||||
|
await load_fixtures_by_context(session, fixtures, Context.TESTING)
|
||||||
|
```
|
||||||
|
|
||||||
|
Directly by name with [`load_fixtures`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.load_fixtures):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import load_fixtures
|
||||||
|
|
||||||
|
async with db_context() as session:
|
||||||
|
await load_fixtures(session, fixtures, "roles", "test_users")
|
||||||
|
```
|
||||||
|
|
||||||
|
Both functions return a `dict[str, list[...]]` mapping each fixture name to the list of loaded instances.
|
||||||
|
|
||||||
|
## Contexts
|
||||||
|
|
||||||
|
[`Context`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.Context) is an enum with predefined values:
|
||||||
|
|
||||||
|
| Context | Description |
|
||||||
|
|---------|-------------|
|
||||||
|
| `Context.BASE` | Core data required in all environments |
|
||||||
|
| `Context.TESTING` | Data only loaded during tests |
|
||||||
|
| `Context.DEVELOPMENT` | Data only loaded in development |
|
||||||
|
| `Context.PRODUCTION` | Data only loaded in production |
|
||||||
|
|
||||||
|
A fixture with no `contexts` defined takes `Context.BASE` by default.
|
||||||
|
|
||||||
|
### Custom contexts
|
||||||
|
|
||||||
|
Plain strings and any `Enum` subclass are accepted wherever a `Context` enum is expected.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
class AppContext(str, Enum):
|
||||||
|
STAGING = "staging"
|
||||||
|
DEMO = "demo"
|
||||||
|
|
||||||
|
@fixtures.register(contexts=[AppContext.STAGING])
|
||||||
|
def staging_data():
|
||||||
|
return [Config(key="feature_x", enabled=True)]
|
||||||
|
|
||||||
|
await load_fixtures_by_context(session, fixtures, AppContext.STAGING)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Default context for a registry
|
||||||
|
|
||||||
|
Pass `contexts` to `FixtureRegistry` to set a default for all fixtures registered in it:
|
||||||
|
|
||||||
|
```python
|
||||||
|
testing_registry = FixtureRegistry(contexts=[Context.TESTING])
|
||||||
|
|
||||||
|
@testing_registry.register # implicitly contexts=[Context.TESTING]
|
||||||
|
def test_orders():
|
||||||
|
return [Order(id=1, total=99)]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Same fixture name, multiple context variants
|
||||||
|
|
||||||
|
The same fixture name may be registered under different (non-overlapping) context sets. When multiple contexts are loaded together, all matching variants are merged:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@fixtures.register(contexts=[Context.BASE])
|
||||||
|
def users():
|
||||||
|
return [User(id=1, username="admin")]
|
||||||
|
|
||||||
|
@fixtures.register(contexts=[Context.TESTING])
|
||||||
|
def users():
|
||||||
|
return [User(id=2, username="tester")]
|
||||||
|
|
||||||
|
# loads both admin and tester
|
||||||
|
await load_fixtures_by_context(session, fixtures, Context.BASE, Context.TESTING)
|
||||||
|
```
|
||||||
|
|
||||||
|
Registering two variants with overlapping context sets raises `ValueError`.
|
||||||
|
|
||||||
|
## Load strategies
|
||||||
|
|
||||||
|
[`LoadStrategy`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.LoadStrategy) controls how the fixture loader handles rows that already exist:
|
||||||
|
|
||||||
|
| Strategy | Description |
|
||||||
|
|----------|-------------|
|
||||||
|
| `LoadStrategy.INSERT` | Insert only, fail on duplicates |
|
||||||
|
| `LoadStrategy.MERGE` | Insert or update on conflict (default) |
|
||||||
|
| `LoadStrategy.SKIP_EXISTING` | Skip rows that already exist |
|
||||||
|
|
||||||
|
```python
|
||||||
|
await load_fixtures_by_context(
|
||||||
|
session, fixtures, Context.BASE, strategy=LoadStrategy.SKIP_EXISTING
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Merging registries
|
||||||
|
|
||||||
|
Split fixture definitions across modules and merge them:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from myapp.fixtures.dev import dev_fixtures
|
||||||
|
from myapp.fixtures.prod import prod_fixtures
|
||||||
|
|
||||||
|
fixtures = FixtureRegistry()
|
||||||
|
fixtures.include_registry(registry=dev_fixtures)
|
||||||
|
fixtures.include_registry(registry=prod_fixtures)
|
||||||
|
```
|
||||||
|
|
||||||
|
Fixtures with the same name are allowed as long as their context sets do not overlap. Conflicting contexts raise `ValueError`.
|
||||||
|
|
||||||
|
## Looking up fixture instances
|
||||||
|
|
||||||
|
[`get_obj_by_attr`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.get_obj_by_attr) retrieves a specific instance from a fixture function by attribute value — useful when building cross-fixture `depends_on` relationships:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import get_obj_by_attr
|
||||||
|
|
||||||
|
@fixtures.register(depends_on=["roles"])
|
||||||
|
def users():
|
||||||
|
admin_role = get_obj_by_attr(roles, "name", "admin")
|
||||||
|
return [User(id=1, username="alice", role_id=admin_role.id)]
|
||||||
|
```
|
||||||
|
|
||||||
|
Raises `StopIteration` if no matching instance is found.
|
||||||
|
|
||||||
|
## Pytest integration
|
||||||
|
|
||||||
|
Use [`register_fixtures`](../reference/pytest.md#fastapi_toolsets.pytest.plugin.register_fixtures) to expose each fixture in your registry as an injectable pytest fixture named `fixture_{name}` by default:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# conftest.py
|
||||||
|
import pytest
|
||||||
|
from fastapi_toolsets.pytest import create_db_session, register_fixtures
|
||||||
|
from app.fixtures import registry
|
||||||
|
from app.models import Base
|
||||||
|
|
||||||
|
DATABASE_URL = "postgresql+asyncpg://user:pass@localhost/test_db"
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session():
|
||||||
|
async with create_db_session(database_url=DATABASE_URL, base=Base, cleanup=True) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
register_fixtures(registry=registry, namespace=globals())
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# test_users.py
|
||||||
|
async def test_user_can_login(fixture_users: list[User], fixture_roles: list[Role]):
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
The load order is resolved automatically from the `depends_on` declarations in your registry. Each generated fixture receives `db_session` as a dependency and returns the list of loaded model instances.
|
||||||
|
|
||||||
|
## CLI integration
|
||||||
|
|
||||||
|
Fixtures can be triggered from the CLI. See the [CLI module](cli.md) for setup instructions.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/fixtures.md)
|
||||||
39
docs/module/logger.md
Normal file
39
docs/module/logger.md
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# Logger
|
||||||
|
|
||||||
|
Lightweight logging utilities with consistent formatting and uvicorn integration.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `logger` module provides two helpers: one to configure the root logger (and uvicorn loggers) at startup, and one to retrieve a named logger anywhere in your codebase.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
Call [`configure_logging`](../reference/logger.md#fastapi_toolsets.logger.configure_logging) once at application startup:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import configure_logging
|
||||||
|
|
||||||
|
configure_logging(level="INFO")
|
||||||
|
```
|
||||||
|
|
||||||
|
This sets up a stdout handler with a consistent format and also configures uvicorn's access and error loggers so all log output shares the same style.
|
||||||
|
|
||||||
|
## Getting a logger
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__)
|
||||||
|
logger.info("User created")
|
||||||
|
```
|
||||||
|
|
||||||
|
When called without arguments, [`get_logger`](../reference/logger.md#fastapi_toolsets.logger.get_logger) auto-detects the caller's module name via frame inspection:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Equivalent to get_logger(name=__name__)
|
||||||
|
logger = get_logger()
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/logger.md)
|
||||||
109
docs/module/metrics.md
Normal file
109
docs/module/metrics.md
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
# Metrics
|
||||||
|
|
||||||
|
Prometheus metrics integration with a decorator-based registry and multi-process support.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
=== "uv"
|
||||||
|
``` bash
|
||||||
|
uv add "fastapi-toolsets[metrics]"
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "pip"
|
||||||
|
``` bash
|
||||||
|
pip install "fastapi-toolsets[metrics]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `metrics` module provides a [`MetricsRegistry`](../reference/metrics.md#fastapi_toolsets.metrics.registry.MetricsRegistry) to declare Prometheus metrics with decorators, and an [`init_metrics`](../reference/metrics.md#fastapi_toolsets.metrics.handler.init_metrics) function to mount a `/metrics` endpoint on your FastAPI app.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.metrics import MetricsRegistry, init_metrics
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
metrics = MetricsRegistry()
|
||||||
|
|
||||||
|
init_metrics(app=app, registry=metrics)
|
||||||
|
```
|
||||||
|
|
||||||
|
This mounts the `/metrics` endpoint that Prometheus can scrape.
|
||||||
|
|
||||||
|
## Declaring metrics
|
||||||
|
|
||||||
|
### Providers
|
||||||
|
|
||||||
|
Providers are called once at startup by `init_metrics`. The return value (the Prometheus metric object) is stored in the registry and can be retrieved later with [`registry.get(name)`](../reference/metrics.md#fastapi_toolsets.metrics.registry.MetricsRegistry.get).
|
||||||
|
|
||||||
|
Use providers when you want **deferred initialization**: the Prometheus metric is not registered with the global `CollectorRegistry` until `init_metrics` runs, not at import time. This is particularly useful for testing — importing the module in a test suite without calling `init_metrics` leaves no metrics registered, avoiding cross-test pollution.
|
||||||
|
|
||||||
|
It is also useful when metrics are defined across multiple modules and merged with `include_registry`: any code that needs a metric can call `metrics.get()` on the shared registry instead of importing the metric directly from its origin module.
|
||||||
|
|
||||||
|
If neither of these applies to you, declaring metrics at module level (e.g. `HTTP_REQUESTS = Counter(...)`) is simpler and equally valid.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from prometheus_client import Counter, Histogram
|
||||||
|
|
||||||
|
@metrics.register
|
||||||
|
def http_requests():
|
||||||
|
return Counter("http_requests_total", "Total HTTP requests", ["method", "status"])
|
||||||
|
|
||||||
|
@metrics.register
|
||||||
|
def request_duration():
|
||||||
|
return Histogram("request_duration_seconds", "Request duration")
|
||||||
|
```
|
||||||
|
|
||||||
|
To use a provider's metric elsewhere (e.g. in a middleware), call `metrics.get()` inside the handler — **not** at module level, as providers are only initialized when `init_metrics` runs:
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def metrics_middleware(request: Request, call_next):
|
||||||
|
response = await call_next(request)
|
||||||
|
metrics.get("http_requests").labels(
|
||||||
|
method=request.method, status=response.status_code
|
||||||
|
).inc()
|
||||||
|
return response
|
||||||
|
```
|
||||||
|
|
||||||
|
### Collectors
|
||||||
|
|
||||||
|
Collectors are called on every scrape. Use them for metrics that reflect current state (e.g. gauges).
|
||||||
|
|
||||||
|
!!! warning "Declare the metric at module level"
|
||||||
|
Do **not** instantiate the Prometheus metric inside the collector function. Doing so recreates it on every scrape, raising `ValueError: Duplicated timeseries in CollectorRegistry`. Declare it once at module level instead:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from prometheus_client import Gauge
|
||||||
|
|
||||||
|
_queue_depth = Gauge("queue_depth", "Current queue depth")
|
||||||
|
|
||||||
|
@metrics.register(collect=True)
|
||||||
|
def collect_queue_depth():
|
||||||
|
_queue_depth.set(get_current_queue_depth())
|
||||||
|
```
|
||||||
|
|
||||||
|
## Merging registries
|
||||||
|
|
||||||
|
Split metrics definitions across modules and merge them:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from myapp.metrics.http import http_metrics
|
||||||
|
from myapp.metrics.db import db_metrics
|
||||||
|
|
||||||
|
metrics = MetricsRegistry()
|
||||||
|
metrics.include_registry(registry=http_metrics)
|
||||||
|
metrics.include_registry(registry=db_metrics)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Multi-process mode
|
||||||
|
|
||||||
|
Multi-process support is enabled automatically when the `PROMETHEUS_MULTIPROC_DIR` environment variable is set. No code changes are required.
|
||||||
|
|
||||||
|
!!! warning "Environment variable name"
|
||||||
|
The correct variable is `PROMETHEUS_MULTIPROC_DIR` (not `PROMETHEUS_MULTIPROCESS_DIR`).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/metrics.md)
|
||||||
234
docs/module/models.md
Normal file
234
docs/module/models.md
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
# Models
|
||||||
|
|
||||||
|
!!! info "Added in `v2.0`"
|
||||||
|
|
||||||
|
Reusable SQLAlchemy 2.0 mixins for common column patterns, designed to be composed freely on any `DeclarativeBase` model.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `models` module provides mixins that each add a single, well-defined column behaviour. They work with standard SQLAlchemy 2.0 declarative syntax and are fully compatible with `AsyncSession`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import UUIDMixin, TimestampMixin
|
||||||
|
|
||||||
|
class Article(Base, UUIDMixin, TimestampMixin):
|
||||||
|
__tablename__ = "articles"
|
||||||
|
|
||||||
|
title: Mapped[str]
|
||||||
|
content: Mapped[str]
|
||||||
|
```
|
||||||
|
|
||||||
|
All timestamp columns are timezone-aware (`TIMESTAMPTZ`). All defaults are server-side (`clock_timestamp()`), so they are also applied when inserting rows via raw SQL outside the ORM.
|
||||||
|
|
||||||
|
## Mixins
|
||||||
|
|
||||||
|
### [`UUIDMixin`](../reference/models.md#fastapi_toolsets.models.UUIDMixin)
|
||||||
|
|
||||||
|
Adds a `id: UUID` primary key generated server-side by PostgreSQL using `gen_random_uuid()`. The value is retrieved via `RETURNING` after insert, so it is available on the Python object immediately after `flush()`.
|
||||||
|
|
||||||
|
!!! warning "Requires PostgreSQL 13+"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import UUIDMixin
|
||||||
|
|
||||||
|
class User(Base, UUIDMixin):
|
||||||
|
__tablename__ = "users"
|
||||||
|
|
||||||
|
username: Mapped[str]
|
||||||
|
|
||||||
|
# id is None before flush
|
||||||
|
user = User(username="alice")
|
||||||
|
session.add(user)
|
||||||
|
await session.flush()
|
||||||
|
print(user.id) # UUID('...')
|
||||||
|
```
|
||||||
|
|
||||||
|
### [`UUIDv7Mixin`](../reference/models.md#fastapi_toolsets.models.UUIDv7Mixin)
|
||||||
|
|
||||||
|
!!! info "Added in `v2.3`"
|
||||||
|
|
||||||
|
Adds a `id: UUID` primary key generated server-side by PostgreSQL using `uuidv7()`. It's a time-ordered UUID format that encodes a millisecond-precision timestamp in the most significant bits, making it naturally sortable and index-friendly.
|
||||||
|
|
||||||
|
!!! warning "Requires PostgreSQL 18+"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import UUIDv7Mixin
|
||||||
|
|
||||||
|
class Event(Base, UUIDv7Mixin):
|
||||||
|
__tablename__ = "events"
|
||||||
|
|
||||||
|
name: Mapped[str]
|
||||||
|
|
||||||
|
# id is None before flush
|
||||||
|
event = Event(name="user.signup")
|
||||||
|
session.add(event)
|
||||||
|
await session.flush()
|
||||||
|
print(event.id) # UUID('019...')
|
||||||
|
```
|
||||||
|
|
||||||
|
### [`CreatedAtMixin`](../reference/models.md#fastapi_toolsets.models.CreatedAtMixin)
|
||||||
|
|
||||||
|
Adds a `created_at: datetime` column set to `clock_timestamp()` on insert. The column has no `onupdate` hook — it is intentionally immutable after the row is created.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import UUIDMixin, CreatedAtMixin
|
||||||
|
|
||||||
|
class Order(Base, UUIDMixin, CreatedAtMixin):
|
||||||
|
__tablename__ = "orders"
|
||||||
|
|
||||||
|
total: Mapped[float]
|
||||||
|
```
|
||||||
|
|
||||||
|
### [`UpdatedAtMixin`](../reference/models.md#fastapi_toolsets.models.UpdatedAtMixin)
|
||||||
|
|
||||||
|
Adds an `updated_at: datetime` column set to `clock_timestamp()` on insert and automatically updated to `clock_timestamp()` on every ORM-level update (via SQLAlchemy's `onupdate` hook).
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import UUIDMixin, UpdatedAtMixin
|
||||||
|
|
||||||
|
class Post(Base, UUIDMixin, UpdatedAtMixin):
|
||||||
|
__tablename__ = "posts"
|
||||||
|
|
||||||
|
title: Mapped[str]
|
||||||
|
|
||||||
|
post = Post(title="Hello")
|
||||||
|
await session.flush()
|
||||||
|
await session.refresh(post)
|
||||||
|
|
||||||
|
post.title = "Hello World"
|
||||||
|
await session.flush()
|
||||||
|
await session.refresh(post)
|
||||||
|
print(post.updated_at)
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
`updated_at` is updated by SQLAlchemy at ORM flush time. If you update rows via raw SQL (e.g. `UPDATE posts SET ...`), the column will **not** be updated automatically — use a database trigger if you need that guarantee.
|
||||||
|
|
||||||
|
### [`TimestampMixin`](../reference/models.md#fastapi_toolsets.models.TimestampMixin)
|
||||||
|
|
||||||
|
Convenience mixin that combines [`CreatedAtMixin`](../reference/models.md#fastapi_toolsets.models.CreatedAtMixin) and [`UpdatedAtMixin`](../reference/models.md#fastapi_toolsets.models.UpdatedAtMixin). Equivalent to inheriting both.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import UUIDMixin, TimestampMixin
|
||||||
|
|
||||||
|
class Article(Base, UUIDMixin, TimestampMixin):
|
||||||
|
__tablename__ = "articles"
|
||||||
|
|
||||||
|
title: Mapped[str]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Lifecycle events
|
||||||
|
|
||||||
|
The event system provides lifecycle callbacks that fire **after commit**. If the transaction rolls back, no callback fires.
|
||||||
|
|
||||||
|
### Setup
|
||||||
|
|
||||||
|
Event dispatch requires [`EventSession`](../reference/models.md#fastapi_toolsets.models.EventSession). Pass it as the session class when creating your session factory:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
|
from fastapi_toolsets.models import EventSession
|
||||||
|
|
||||||
|
engine = create_async_engine("postgresql+asyncpg://...")
|
||||||
|
SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=EventSession)
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! info "Callbacks fire on `session.commit()` only — not on savepoints."
|
||||||
|
Savepoints created by [`get_transaction`](db.md) or `begin_nested()` do **not**
|
||||||
|
trigger callbacks. All events accumulated across flushes are dispatched once
|
||||||
|
when the outermost `commit()` is called.
|
||||||
|
|
||||||
|
### Events
|
||||||
|
|
||||||
|
Three event types are available, each corresponding to a [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) value:
|
||||||
|
|
||||||
|
| Event | Trigger |
|
||||||
|
|---|---|
|
||||||
|
| `ModelEvent.CREATE` | After `INSERT` commit |
|
||||||
|
| `ModelEvent.DELETE` | After `DELETE` commit |
|
||||||
|
| `ModelEvent.UPDATE` | After `UPDATE` commit on a watched field |
|
||||||
|
|
||||||
|
!!! warning "Callbacks fire only for ORM-level changes. Rows updated via raw SQL (`UPDATE ... SET ...`) are not detected."
|
||||||
|
|
||||||
|
### Watched fields
|
||||||
|
|
||||||
|
Set `__watched_fields__` on the model to restrict which field changes trigger `UPDATE` events. It must be a `tuple[str, ...]` — any other type raises `TypeError`:
|
||||||
|
|
||||||
|
| Class attribute | `UPDATE` behaviour |
|
||||||
|
|---|---|
|
||||||
|
| `__watched_fields__ = ("status", "role")` | Only fires when `status` or `role` changes |
|
||||||
|
| *(not set)* | Fires when **any** mapped field changes |
|
||||||
|
|
||||||
|
`__watched_fields__` is inherited through the class hierarchy via normal Python MRO. A subclass can override it:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class Order(Base, UUIDMixin):
|
||||||
|
__watched_fields__ = ("status",)
|
||||||
|
...
|
||||||
|
|
||||||
|
class UrgentOrder(Order):
|
||||||
|
# inherits __watched_fields__ = ("status",)
|
||||||
|
...
|
||||||
|
|
||||||
|
class PriorityOrder(Order):
|
||||||
|
__watched_fields__ = ("priority",)
|
||||||
|
# overrides parent — UPDATE fires only for priority changes
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
### Registering handlers
|
||||||
|
|
||||||
|
Register handlers with the [`listens_for`](../reference/models.md#fastapi_toolsets.models.listens_for) decorator. Every callback receives three arguments: the model instance, the [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) that triggered it, and a `changes` dict (`None` for `CREATE` and `DELETE`):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import ModelEvent, UUIDMixin, listens_for
|
||||||
|
|
||||||
|
class Order(Base, UUIDMixin):
|
||||||
|
__tablename__ = "orders"
|
||||||
|
__watched_fields__ = ("status",)
|
||||||
|
|
||||||
|
status: Mapped[str]
|
||||||
|
|
||||||
|
@listens_for(Order, [ModelEvent.CREATE])
|
||||||
|
async def on_order_created(order: Order, event_type: ModelEvent, changes: None):
|
||||||
|
await notify_new_order(order.id)
|
||||||
|
|
||||||
|
@listens_for(Order, [ModelEvent.DELETE])
|
||||||
|
async def on_order_deleted(order: Order, event_type: ModelEvent, changes: None):
|
||||||
|
await notify_order_cancelled(order.id)
|
||||||
|
|
||||||
|
@listens_for(Order, [ModelEvent.UPDATE])
|
||||||
|
async def on_order_updated(order: Order, event_type: ModelEvent, changes: dict):
|
||||||
|
if "status" in changes:
|
||||||
|
await notify_status_change(order.id, changes["status"])
|
||||||
|
```
|
||||||
|
|
||||||
|
Multiple handlers can be registered for the same model and event. Handlers registered on a parent class also fire for subclass instances.
|
||||||
|
|
||||||
|
A single handler can listen for multiple events at once. When `event_types` is omitted, the handler fires for all events:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@listens_for(Order, [ModelEvent.CREATE, ModelEvent.UPDATE])
|
||||||
|
async def on_order_changed(order: Order, event_type: ModelEvent, changes: dict | None):
|
||||||
|
await invalidate_cache(order.id)
|
||||||
|
|
||||||
|
@listens_for(Order) # all events
|
||||||
|
async def on_any_order_event(order: Order, event_type: ModelEvent, changes: dict | None):
|
||||||
|
await audit_log(order.id, event_type)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Field changes format
|
||||||
|
|
||||||
|
The `changes` dict maps each watched field that changed to `{"old": ..., "new": ...}`. Only fields that actually changed are included. For `CREATE` and `DELETE` events, `changes` is `None`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# CREATE / DELETE → changes is None
|
||||||
|
# status changed → {"status": {"old": "pending", "new": "shipped"}}
|
||||||
|
# two fields changed → {"status": {...}, "assigned_to": {...}}
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! info "Multiple flushes in one transaction are merged: the earliest `old` and latest `new` are preserved, and `on_update` fires only once per commit."
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/models.md)
|
||||||
95
docs/module/pytest.md
Normal file
95
docs/module/pytest.md
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
# Pytest
|
||||||
|
|
||||||
|
Testing helpers for FastAPI applications with async client, database sessions, and parallel worker support.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
=== "uv"
|
||||||
|
``` bash
|
||||||
|
uv add "fastapi-toolsets[pytest]"
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "pip"
|
||||||
|
``` bash
|
||||||
|
pip install "fastapi-toolsets[pytest]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `pytest` module provides utilities for setting up async test clients, managing test database sessions, and supporting parallel test execution with `pytest-xdist`.
|
||||||
|
|
||||||
|
## Creating an async client
|
||||||
|
|
||||||
|
Use [`create_async_client`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_async_client) to get an `httpx.AsyncClient` configured for your FastAPI app:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import create_async_client
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def http_client(db_session):
|
||||||
|
async def _override_get_db():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
async with create_async_client(
|
||||||
|
app=app,
|
||||||
|
base_url="http://127.0.0.1/api/v1",
|
||||||
|
dependency_overrides={get_db: _override_get_db},
|
||||||
|
) as c:
|
||||||
|
yield c
|
||||||
|
```
|
||||||
|
|
||||||
|
## Database sessions in tests
|
||||||
|
|
||||||
|
Use [`create_db_session`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_db_session) to create an isolated `AsyncSession` for a test, combined with [`create_worker_database`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_worker_database) to set up a per-worker database:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import create_worker_database, create_db_session
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
async def worker_db_url():
|
||||||
|
async with create_worker_database(
|
||||||
|
database_url=str(settings.SQLALCHEMY_DATABASE_URI)
|
||||||
|
) as url:
|
||||||
|
yield url
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session(worker_db_url):
|
||||||
|
async with create_db_session(
|
||||||
|
database_url=worker_db_url, base=Base, cleanup=True
|
||||||
|
) as session:
|
||||||
|
yield session
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! info
|
||||||
|
In this example, the database is reset between each test using the argument `cleanup=True`.
|
||||||
|
|
||||||
|
Use [`worker_database_url`](../reference/pytest.md#fastapi_toolsets.pytest.utils.worker_database_url) to derive the per-worker URL manually if needed:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import worker_database_url
|
||||||
|
|
||||||
|
url = worker_database_url("postgresql+asyncpg://user:pass@localhost/test_db", default_test_db="test")
|
||||||
|
# e.g. "postgresql+asyncpg://user:pass@localhost/test_db_gw0" under xdist
|
||||||
|
```
|
||||||
|
|
||||||
|
## Parallel testing with pytest-xdist
|
||||||
|
|
||||||
|
The examples above are already compatible with parallel test execution with `pytest-xdist`.
|
||||||
|
|
||||||
|
## Cleaning up tables
|
||||||
|
|
||||||
|
If you want to manually clean up a database you can use [`cleanup_tables`](../reference/db.md#fastapi_toolsets.db.cleanup_tables), this will truncate all tables between tests for fast isolation:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import cleanup_tables
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
async def clean(db_session):
|
||||||
|
yield
|
||||||
|
await cleanup_tables(session=db_session, base=Base)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/pytest.md)
|
||||||
140
docs/module/schemas.md
Normal file
140
docs/module/schemas.md
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
# Schemas
|
||||||
|
|
||||||
|
Standardized Pydantic response models for consistent API responses across your FastAPI application.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `schemas` module provides generic response wrappers that enforce a uniform response structure. All models use `from_attributes=True` for ORM compatibility and `validate_assignment=True` for runtime type safety.
|
||||||
|
|
||||||
|
## Response models
|
||||||
|
|
||||||
|
### [`Response[T]`](../reference/schemas.md#fastapi_toolsets.schemas.Response)
|
||||||
|
|
||||||
|
The most common wrapper for a single resource response.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.schemas import Response
|
||||||
|
|
||||||
|
@router.get("/users/{id}")
|
||||||
|
async def get_user(user: User = UserDep) -> Response[UserSchema]:
|
||||||
|
return Response(data=user, message="User retrieved")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Paginated response models
|
||||||
|
|
||||||
|
Three classes wrap paginated list results. Pick the one that matches your endpoint's strategy:
|
||||||
|
|
||||||
|
| Class | `pagination` type | `pagination_type` field | Use when |
|
||||||
|
|---|---|---|---|
|
||||||
|
| [`OffsetPaginatedResponse[T]`](#offsetpaginatedresponset) | `OffsetPagination` | `"offset"` (fixed) | endpoint always uses offset |
|
||||||
|
| [`CursorPaginatedResponse[T]`](#cursorpaginatedresponset) | `CursorPagination` | `"cursor"` (fixed) | endpoint always uses cursor |
|
||||||
|
| [`PaginatedResponse[T]`](#paginatedresponset) | `OffsetPagination \| CursorPagination` | — | unified endpoint supporting both strategies |
|
||||||
|
|
||||||
|
#### [`OffsetPaginatedResponse[T]`](../reference/schemas.md#fastapi_toolsets.schemas.OffsetPaginatedResponse)
|
||||||
|
|
||||||
|
!!! info "Added in `v2.3.0`"
|
||||||
|
|
||||||
|
Use as the return type when the endpoint always uses [`offset_paginate`](crud.md#offset-pagination). The `pagination` field is guaranteed to be an [`OffsetPagination`](../reference/schemas.md#fastapi_toolsets.schemas.OffsetPagination) object; the response always includes a `pagination_type: "offset"` discriminator.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.schemas import OffsetPaginatedResponse
|
||||||
|
|
||||||
|
@router.get("/users")
|
||||||
|
async def list_users(
|
||||||
|
page: int = 1,
|
||||||
|
items_per_page: int = 20,
|
||||||
|
) -> OffsetPaginatedResponse[UserSchema]:
|
||||||
|
return await UserCrud.offset_paginate(
|
||||||
|
session, page=page, items_per_page=items_per_page, schema=UserSchema
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response shape:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"pagination_type": "offset",
|
||||||
|
"data": ["..."],
|
||||||
|
"pagination": {
|
||||||
|
"total_count": 100,
|
||||||
|
"page": 1,
|
||||||
|
"items_per_page": 20,
|
||||||
|
"has_more": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### [`CursorPaginatedResponse[T]`](../reference/schemas.md#fastapi_toolsets.schemas.CursorPaginatedResponse)
|
||||||
|
|
||||||
|
!!! info "Added in `v2.3.0`"
|
||||||
|
|
||||||
|
Use as the return type when the endpoint always uses [`cursor_paginate`](crud.md#cursor-pagination). The `pagination` field is guaranteed to be a [`CursorPagination`](../reference/schemas.md#fastapi_toolsets.schemas.CursorPagination) object; the response always includes a `pagination_type: "cursor"` discriminator.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.schemas import CursorPaginatedResponse
|
||||||
|
|
||||||
|
@router.get("/events")
|
||||||
|
async def list_events(
|
||||||
|
cursor: str | None = None,
|
||||||
|
items_per_page: int = 20,
|
||||||
|
) -> CursorPaginatedResponse[EventSchema]:
|
||||||
|
return await EventCrud.cursor_paginate(
|
||||||
|
session, cursor=cursor, items_per_page=items_per_page, schema=EventSchema
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response shape:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"pagination_type": "cursor",
|
||||||
|
"data": ["..."],
|
||||||
|
"pagination": {
|
||||||
|
"next_cursor": "eyJpZCI6IDQyfQ==",
|
||||||
|
"prev_cursor": null,
|
||||||
|
"items_per_page": 20,
|
||||||
|
"has_more": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### [`PaginatedResponse[T]`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse)
|
||||||
|
|
||||||
|
Return type for endpoints that support **both** pagination strategies via a `pagination_type` query parameter (using [`paginate()`](crud.md#unified-paginate--both-strategies-on-one-endpoint)).
|
||||||
|
|
||||||
|
When used as a return annotation, `PaginatedResponse[T]` automatically expands to `Annotated[Union[CursorPaginatedResponse[T], OffsetPaginatedResponse[T]], Field(discriminator="pagination_type")]`, so FastAPI emits a proper `oneOf` + discriminator in the OpenAPI schema with no extra boilerplate:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.crud import PaginationType
|
||||||
|
from fastapi_toolsets.schemas import PaginatedResponse
|
||||||
|
|
||||||
|
@router.get("/users")
|
||||||
|
async def list_users(
|
||||||
|
pagination_type: PaginationType = PaginationType.OFFSET,
|
||||||
|
page: int = 1,
|
||||||
|
cursor: str | None = None,
|
||||||
|
items_per_page: int = 20,
|
||||||
|
) -> PaginatedResponse[UserSchema]:
|
||||||
|
return await UserCrud.paginate(
|
||||||
|
session,
|
||||||
|
pagination_type=pagination_type,
|
||||||
|
page=page,
|
||||||
|
cursor=cursor,
|
||||||
|
items_per_page=items_per_page,
|
||||||
|
schema=UserSchema,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Pagination metadata models
|
||||||
|
|
||||||
|
The optional `filter_attributes` field is populated when `facet_fields` are configured on the CRUD class (see [Filter attributes](crud.md#filter-attributes-facets)). It is `None` by default and can be hidden from API responses with `response_model_exclude_none=True`.
|
||||||
|
|
||||||
|
### [`ErrorResponse`](../reference/schemas.md#fastapi_toolsets.schemas.ErrorResponse)
|
||||||
|
|
||||||
|
Returned automatically by the exceptions handler.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/schemas.md)
|
||||||
267
docs/module/security.md
Normal file
267
docs/module/security.md
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
# Security
|
||||||
|
|
||||||
|
Composable authentication helpers for FastAPI that use `Security()` for OpenAPI documentation and accept user-provided validator functions with full type flexibility.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `security` module provides four auth source classes and a `MultiAuth` factory. Each class wraps a FastAPI security scheme for OpenAPI and accepts a validator function called as:
|
||||||
|
|
||||||
|
```python
|
||||||
|
await validator(credential, **kwargs)
|
||||||
|
```
|
||||||
|
|
||||||
|
where `kwargs` are the extra keyword arguments provided at instantiation (roles, permissions, enums, etc.). The validator returns the authenticated identity (e.g. a `User` model) which becomes the route dependency value.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi import Security
|
||||||
|
from fastapi_toolsets.security import BearerTokenAuth
|
||||||
|
|
||||||
|
async def verify_token(token: str, *, role: str) -> User:
|
||||||
|
user = await db.get_by_token(token)
|
||||||
|
if not user or user.role != role:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
return user
|
||||||
|
|
||||||
|
bearer_admin = BearerTokenAuth(verify_token, role="admin")
|
||||||
|
|
||||||
|
@app.get("/admin")
|
||||||
|
async def admin_route(user: User = Security(bearer_admin)):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
## Auth sources
|
||||||
|
|
||||||
|
### [`BearerTokenAuth`](../reference/security.md#fastapi_toolsets.security.BearerTokenAuth)
|
||||||
|
|
||||||
|
Reads the `Authorization: Bearer <token>` header. Wraps `HTTPBearer` for OpenAPI.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.security import BearerTokenAuth
|
||||||
|
|
||||||
|
bearer = BearerTokenAuth(validator=verify_token)
|
||||||
|
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user: User = Security(bearer)):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Token prefix
|
||||||
|
|
||||||
|
The optional `prefix` parameter restricts a `BearerTokenAuth` instance to tokens
|
||||||
|
that start with a given string. The prefix is **kept** in the value passed to the
|
||||||
|
validator — store and compare tokens with their prefix included.
|
||||||
|
|
||||||
|
This lets you deploy multiple `BearerTokenAuth` instances in the same application
|
||||||
|
and disambiguate them efficiently in `MultiAuth`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
user_bearer = BearerTokenAuth(verify_user, prefix="user_") # matches "Bearer user_..."
|
||||||
|
org_bearer = BearerTokenAuth(verify_org, prefix="org_") # matches "Bearer org_..."
|
||||||
|
```
|
||||||
|
|
||||||
|
Use [`generate_token()`](#token-generation) to create correctly-prefixed tokens.
|
||||||
|
|
||||||
|
#### Token generation
|
||||||
|
|
||||||
|
`BearerTokenAuth.generate_token()` produces a secure random token ready to store
|
||||||
|
in your database and return to the client. If a prefix is configured it is
|
||||||
|
prepended automatically:
|
||||||
|
|
||||||
|
```python
|
||||||
|
bearer = BearerTokenAuth(verify_token, prefix="user_")
|
||||||
|
|
||||||
|
token = bearer.generate_token() # e.g. "user_Xk3mN..."
|
||||||
|
await db.store_token(user_id, token)
|
||||||
|
return {"access_token": token, "token_type": "bearer"}
|
||||||
|
```
|
||||||
|
|
||||||
|
The client sends `Authorization: Bearer user_Xk3mN...` and the validator receives
|
||||||
|
the full token (prefix included) to compare against the stored value.
|
||||||
|
|
||||||
|
### [`CookieAuth`](../reference/security.md#fastapi_toolsets.security.CookieAuth)
|
||||||
|
|
||||||
|
Reads a named cookie. Wraps `APIKeyCookie` for OpenAPI.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.security import CookieAuth
|
||||||
|
|
||||||
|
cookie_auth = CookieAuth("session", validator=verify_session)
|
||||||
|
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user: User = Security(cookie_auth)):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
### [`OAuth2Auth`](../reference/security.md#fastapi_toolsets.security.OAuth2Auth)
|
||||||
|
|
||||||
|
Reads the `Authorization: Bearer <token>` header and registers the token endpoint
|
||||||
|
in OpenAPI via `OAuth2PasswordBearer`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.security import OAuth2Auth
|
||||||
|
|
||||||
|
oauth2_auth = OAuth2Auth(token_url="/token", validator=verify_token)
|
||||||
|
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user: User = Security(oauth2_auth)):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
### [`OpenIDAuth`](../reference/security.md#fastapi_toolsets.security.OpenIDAuth)
|
||||||
|
|
||||||
|
Reads the `Authorization: Bearer <token>` header and registers the OpenID Connect
|
||||||
|
discovery URL in OpenAPI via `OpenIdConnect`. Token validation is fully delegated
|
||||||
|
to your validator — use any OIDC / JWT library (`authlib`, `python-jose`, `PyJWT`).
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.security import OpenIDAuth
|
||||||
|
|
||||||
|
async def verify_google_token(token: str, *, audience: str) -> User:
|
||||||
|
payload = jwt.decode(token, google_public_keys, algorithms=["RS256"],
|
||||||
|
audience=audience)
|
||||||
|
return User(email=payload["email"], name=payload["name"])
|
||||||
|
|
||||||
|
google_auth = OpenIDAuth(
|
||||||
|
"https://accounts.google.com/.well-known/openid-configuration",
|
||||||
|
verify_google_token,
|
||||||
|
audience="my-client-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/me")
|
||||||
|
async def me(user: User = Security(google_auth)):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
The discovery URL is used **only for OpenAPI documentation** — no requests are made
|
||||||
|
to it by this class. You are responsible for fetching and caching the provider's
|
||||||
|
public keys in your validator.
|
||||||
|
|
||||||
|
Multiple providers work naturally with `MultiAuth`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
multi = MultiAuth(google_auth, github_auth)
|
||||||
|
|
||||||
|
@app.get("/data")
|
||||||
|
async def data(user: User = Security(multi)):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
## Typed validator kwargs
|
||||||
|
|
||||||
|
All auth classes forward extra instantiation keyword arguments to the validator.
|
||||||
|
Arguments can be any type — enums, strings, integers, etc. The validator returns
|
||||||
|
the authenticated identity, which FastAPI injects directly into the route handler.
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def verify_token(token: str, *, role: Role, permission: str) -> User:
|
||||||
|
user = await decode_token(token)
|
||||||
|
if user.role != role or permission not in user.permissions:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
return user
|
||||||
|
|
||||||
|
bearer = BearerTokenAuth(verify_token, role=Role.ADMIN, permission="billing:read")
|
||||||
|
```
|
||||||
|
|
||||||
|
Each auth instance is self-contained — create a separate instance per distinct
|
||||||
|
requirement instead of passing requirements through `Security(scopes=[...])`.
|
||||||
|
|
||||||
|
### Using `.require()` inline
|
||||||
|
|
||||||
|
If declaring a new top-level variable per role feels verbose, use `.require()` to
|
||||||
|
create a configured clone directly in the route decorator. The original instance
|
||||||
|
is not mutated:
|
||||||
|
|
||||||
|
```python
|
||||||
|
bearer = BearerTokenAuth(verify_token)
|
||||||
|
|
||||||
|
@app.get("/admin/stats")
|
||||||
|
async def admin_stats(user: User = Security(bearer.require(role=Role.ADMIN))):
|
||||||
|
return {"message": f"Hello admin {user.name}"}
|
||||||
|
|
||||||
|
@app.get("/profile")
|
||||||
|
async def profile(user: User = Security(bearer.require(role=Role.USER))):
|
||||||
|
return {"id": user.id, "name": user.name}
|
||||||
|
```
|
||||||
|
|
||||||
|
`.require()` kwargs are merged over existing ones — new values win on conflict.
|
||||||
|
The `prefix` (for `BearerTokenAuth`) and cookie name (for `CookieAuth`) are
|
||||||
|
always preserved.
|
||||||
|
|
||||||
|
`.require()` instances work transparently inside `MultiAuth`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
multi = MultiAuth(
|
||||||
|
user_bearer.require(role=Role.USER),
|
||||||
|
org_bearer.require(role=Role.ADMIN),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## MultiAuth
|
||||||
|
|
||||||
|
[`MultiAuth`](../reference/security.md#fastapi_toolsets.security.MultiAuth) combines
|
||||||
|
multiple auth sources into a single callable. Sources are tried in order; the
|
||||||
|
first one that finds a credential wins.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.security import MultiAuth
|
||||||
|
|
||||||
|
multi = MultiAuth(user_bearer, org_bearer, cookie_auth)
|
||||||
|
|
||||||
|
@app.get("/data")
|
||||||
|
async def data_route(user = Security(multi)):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using `.require()` on MultiAuth
|
||||||
|
|
||||||
|
`MultiAuth` also supports `.require()`, which propagates the kwargs to every
|
||||||
|
source that implements it. Sources that do not (e.g. custom `AuthSource`
|
||||||
|
subclasses) are passed through unchanged:
|
||||||
|
|
||||||
|
```python
|
||||||
|
multi = MultiAuth(bearer, cookie)
|
||||||
|
|
||||||
|
@app.get("/admin")
|
||||||
|
async def admin(user: User = Security(multi.require(role=Role.ADMIN))):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
This is equivalent to calling `.require()` on each source individually:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# These two are identical
|
||||||
|
multi.require(role=Role.ADMIN)
|
||||||
|
|
||||||
|
MultiAuth(
|
||||||
|
bearer.require(role=Role.ADMIN),
|
||||||
|
cookie.require(role=Role.ADMIN),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Prefix-based dispatch
|
||||||
|
|
||||||
|
Because `extract()` is pure string matching (no I/O), prefix-based source
|
||||||
|
selection is essentially free. Only the matching source's validator (which may
|
||||||
|
involve DB or network I/O) is ever called:
|
||||||
|
|
||||||
|
```python
|
||||||
|
user_bearer = BearerTokenAuth(verify_user, prefix="user_")
|
||||||
|
org_bearer = BearerTokenAuth(verify_org, prefix="org_")
|
||||||
|
|
||||||
|
multi = MultiAuth(user_bearer, org_bearer)
|
||||||
|
|
||||||
|
# "Bearer user_alice" → only verify_user runs, receives "user_alice"
|
||||||
|
# "Bearer org_acme" → only verify_org runs, receives "org_acme"
|
||||||
|
```
|
||||||
|
|
||||||
|
Tokens are stored and compared **with their prefix** — use `generate_token()` on
|
||||||
|
each source to issue correctly-prefixed tokens:
|
||||||
|
|
||||||
|
```python
|
||||||
|
user_token = user_bearer.generate_token() # "user_..."
|
||||||
|
org_token = org_bearer.generate_token() # "org_..."
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/security.md)
|
||||||
7
docs/overrides/main.html
Normal file
7
docs/overrides/main.html
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
{% extends "base.html" %} {% block extrahead %}
|
||||||
|
<script
|
||||||
|
defer
|
||||||
|
src="https://analytics.d3vyce.fr/script.js"
|
||||||
|
data-website-id="338b8816-7b99-4c6a-82f3-15595be3fd47"
|
||||||
|
></script>
|
||||||
|
{{ super() }} {% endblock %}
|
||||||
27
docs/reference/cli.md
Normal file
27
docs/reference/cli.md
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
# `cli`
|
||||||
|
|
||||||
|
Here's the reference for the CLI configuration helpers used to load settings from `pyproject.toml`.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.cli.config`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.cli.config import (
|
||||||
|
import_from_string,
|
||||||
|
get_config_value,
|
||||||
|
get_fixtures_registry,
|
||||||
|
get_db_context,
|
||||||
|
get_custom_cli,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.import_from_string
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.get_config_value
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.get_fixtures_registry
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.get_db_context
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.get_custom_cli
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.utils.async_command
|
||||||
20
docs/reference/crud.md
Normal file
20
docs/reference/crud.md
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
# `crud`
|
||||||
|
|
||||||
|
Here's the reference for the CRUD classes, factory, and search utilities.
|
||||||
|
|
||||||
|
You can import the main symbols from `fastapi_toolsets.crud`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.crud import CrudFactory, AsyncCrud
|
||||||
|
from fastapi_toolsets.crud.search import SearchConfig, get_searchable_fields, build_search_filters
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.factory.AsyncCrud
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.factory.CrudFactory
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.search.SearchConfig
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.search.get_searchable_fields
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.search.build_search_filters
|
||||||
34
docs/reference/db.md
Normal file
34
docs/reference/db.md
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
# `db`
|
||||||
|
|
||||||
|
Here's the reference for all database session utilities, transaction helpers, and locking functions.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.db`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import (
|
||||||
|
LockMode,
|
||||||
|
cleanup_tables,
|
||||||
|
create_database,
|
||||||
|
create_db_dependency,
|
||||||
|
create_db_context,
|
||||||
|
get_transaction,
|
||||||
|
lock_tables,
|
||||||
|
wait_for_row_change,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.LockMode
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.create_db_dependency
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.create_db_context
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.get_transaction
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.lock_tables
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.wait_for_row_change
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.create_database
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.cleanup_tables
|
||||||
13
docs/reference/dependencies.md
Normal file
13
docs/reference/dependencies.md
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# `dependencies`
|
||||||
|
|
||||||
|
Here's the reference for the FastAPI dependency factory functions.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.dependencies`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.dependencies import PathDependency, BodyDependency
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.dependencies.PathDependency
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.dependencies.BodyDependency
|
||||||
40
docs/reference/exceptions.md
Normal file
40
docs/reference/exceptions.md
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
# `exceptions`
|
||||||
|
|
||||||
|
Here's the reference for all exception classes and handler utilities.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.exceptions`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.exceptions import (
|
||||||
|
ApiException,
|
||||||
|
UnauthorizedError,
|
||||||
|
ForbiddenError,
|
||||||
|
NotFoundError,
|
||||||
|
ConflictError,
|
||||||
|
NoSearchableFieldsError,
|
||||||
|
InvalidFacetFilterError,
|
||||||
|
InvalidOrderFieldError,
|
||||||
|
generate_error_responses,
|
||||||
|
init_exceptions_handlers,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.ApiException
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.UnauthorizedError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.ForbiddenError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.NotFoundError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.ConflictError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.InvalidOrderFieldError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.generate_error_responses
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.handler.init_exceptions_handlers
|
||||||
31
docs/reference/fixtures.md
Normal file
31
docs/reference/fixtures.md
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
# `fixtures`
|
||||||
|
|
||||||
|
Here's the reference for the fixture registry, enums, and loading utilities.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.fixtures`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import (
|
||||||
|
Context,
|
||||||
|
LoadStrategy,
|
||||||
|
Fixture,
|
||||||
|
FixtureRegistry,
|
||||||
|
load_fixtures,
|
||||||
|
load_fixtures_by_context,
|
||||||
|
get_obj_by_attr,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.enum.Context
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.enum.LoadStrategy
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.registry.Fixture
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.registry.FixtureRegistry
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.utils.load_fixtures
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.utils.load_fixtures_by_context
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.utils.get_obj_by_attr
|
||||||
13
docs/reference/logger.md
Normal file
13
docs/reference/logger.md
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# `logger`
|
||||||
|
|
||||||
|
Here's the reference for the logging utilities.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.logger`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import configure_logging, get_logger
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.logger.configure_logging
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.logger.get_logger
|
||||||
15
docs/reference/metrics.md
Normal file
15
docs/reference/metrics.md
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# `metrics`
|
||||||
|
|
||||||
|
Here's the reference for the Prometheus metrics registry and endpoint handler.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.metrics`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.metrics import Metric, MetricsRegistry, init_metrics
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.metrics.registry.Metric
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.metrics.registry.MetricsRegistry
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.metrics.handler.init_metrics
|
||||||
34
docs/reference/models.md
Normal file
34
docs/reference/models.md
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
# `models`
|
||||||
|
|
||||||
|
Here's the reference for the SQLAlchemy model mixins provided by the `models` module.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.models`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import (
|
||||||
|
EventSession,
|
||||||
|
ModelEvent,
|
||||||
|
UUIDMixin,
|
||||||
|
UUIDv7Mixin,
|
||||||
|
CreatedAtMixin,
|
||||||
|
UpdatedAtMixin,
|
||||||
|
TimestampMixin,
|
||||||
|
listens_for,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.models.EventSession
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.models.ModelEvent
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.models.UUIDMixin
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.models.UUIDv7Mixin
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.models.CreatedAtMixin
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.models.UpdatedAtMixin
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.models.TimestampMixin
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.models.listens_for
|
||||||
26
docs/reference/pytest.md
Normal file
26
docs/reference/pytest.md
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# `pytest`
|
||||||
|
|
||||||
|
Here's the reference for all testing utilities and pytest fixtures.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.pytest`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import (
|
||||||
|
register_fixtures,
|
||||||
|
create_async_client,
|
||||||
|
create_db_session,
|
||||||
|
worker_database_url,
|
||||||
|
create_worker_database,
|
||||||
|
cleanup_tables,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.plugin.register_fixtures
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.create_async_client
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.create_db_session
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.worker_database_url
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.create_worker_database
|
||||||
46
docs/reference/schemas.md
Normal file
46
docs/reference/schemas.md
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
# `schemas`
|
||||||
|
|
||||||
|
Here's the reference for all response models and types provided by the `schemas` module.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.schemas`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.schemas import (
|
||||||
|
PydanticBase,
|
||||||
|
ResponseStatus,
|
||||||
|
ApiError,
|
||||||
|
BaseResponse,
|
||||||
|
Response,
|
||||||
|
ErrorResponse,
|
||||||
|
OffsetPagination,
|
||||||
|
CursorPagination,
|
||||||
|
PaginationType,
|
||||||
|
PaginatedResponse,
|
||||||
|
OffsetPaginatedResponse,
|
||||||
|
CursorPaginatedResponse,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.PydanticBase
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.ResponseStatus
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.ApiError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.BaseResponse
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.Response
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.ErrorResponse
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.OffsetPagination
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.CursorPagination
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.PaginationType
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.PaginatedResponse
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.OffsetPaginatedResponse
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.CursorPaginatedResponse
|
||||||
28
docs/reference/security.md
Normal file
28
docs/reference/security.md
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# `security`
|
||||||
|
|
||||||
|
Here's the reference for the authentication helpers provided by the `security` module.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.security`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.security import (
|
||||||
|
AuthSource,
|
||||||
|
BearerTokenAuth,
|
||||||
|
CookieAuth,
|
||||||
|
OAuth2Auth,
|
||||||
|
OpenIDAuth,
|
||||||
|
MultiAuth,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.security.AuthSource
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.security.BearerTokenAuth
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.security.CookieAuth
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.security.OAuth2Auth
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.security.OpenIDAuth
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.security.MultiAuth
|
||||||
0
docs_src/__init__.py
Normal file
0
docs_src/__init__.py
Normal file
0
docs_src/examples/__init__.py
Normal file
0
docs_src/examples/__init__.py
Normal file
0
docs_src/examples/authentication/__init__.py
Normal file
0
docs_src/examples/authentication/__init__.py
Normal file
9
docs_src/examples/authentication/app.py
Normal file
9
docs_src/examples/authentication/app.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||||
|
|
||||||
|
from .routes import router
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app=app)
|
||||||
|
app.include_router(router=router)
|
||||||
9
docs_src/examples/authentication/crud.py
Normal file
9
docs_src/examples/authentication/crud.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
from fastapi_toolsets.crud import CrudFactory
|
||||||
|
|
||||||
|
from .models import OAuthAccount, OAuthProvider, Team, User, UserToken
|
||||||
|
|
||||||
|
TeamCrud = CrudFactory(model=Team)
|
||||||
|
UserCrud = CrudFactory(model=User)
|
||||||
|
UserTokenCrud = CrudFactory(model=UserToken)
|
||||||
|
OAuthProviderCrud = CrudFactory(model=OAuthProvider)
|
||||||
|
OAuthAccountCrud = CrudFactory(model=OAuthAccount)
|
||||||
15
docs_src/examples/authentication/db.py
Normal file
15
docs_src/examples/authentication/db.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
from fastapi import Depends
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from fastapi_toolsets.db import create_db_context, create_db_dependency
|
||||||
|
|
||||||
|
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/postgres"
|
||||||
|
|
||||||
|
engine = create_async_engine(url=DATABASE_URL, future=True)
|
||||||
|
async_session_maker = async_sessionmaker(bind=engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
get_db = create_db_dependency(session_maker=async_session_maker)
|
||||||
|
get_db_context = create_db_context(session_maker=async_session_maker)
|
||||||
|
|
||||||
|
|
||||||
|
SessionDep = Depends(get_db)
|
||||||
105
docs_src/examples/authentication/models.py
Normal file
105
docs_src/examples/authentication/models.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
import enum
|
||||||
|
from datetime import datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
Boolean,
|
||||||
|
DateTime,
|
||||||
|
Enum,
|
||||||
|
ForeignKey,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
UniqueConstraint,
|
||||||
|
)
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID as PG_UUID
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from fastapi_toolsets.models import TimestampMixin, UUIDMixin
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase, UUIDMixin):
|
||||||
|
type_annotation_map = {
|
||||||
|
str: String(),
|
||||||
|
int: Integer(),
|
||||||
|
UUID: PG_UUID(as_uuid=True),
|
||||||
|
datetime: DateTime(timezone=True),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class UserRole(enum.Enum):
|
||||||
|
admin = "admin"
|
||||||
|
moderator = "moderator"
|
||||||
|
user = "user"
|
||||||
|
|
||||||
|
|
||||||
|
class Team(Base, TimestampMixin):
|
||||||
|
__tablename__ = "teams"
|
||||||
|
|
||||||
|
name: Mapped[str] = mapped_column(String, unique=True, index=True)
|
||||||
|
users: Mapped[list["User"]] = relationship(back_populates="team")
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base, TimestampMixin):
|
||||||
|
__tablename__ = "users"
|
||||||
|
|
||||||
|
username: Mapped[str] = mapped_column(String, unique=True, index=True)
|
||||||
|
email: Mapped[str | None] = mapped_column(
|
||||||
|
String, unique=True, index=True, nullable=True
|
||||||
|
)
|
||||||
|
hashed_password: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||||
|
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||||
|
role: Mapped[UserRole] = mapped_column(Enum(UserRole), default=UserRole.user)
|
||||||
|
|
||||||
|
team_id: Mapped[UUID | None] = mapped_column(ForeignKey("teams.id"), nullable=True)
|
||||||
|
team: Mapped["Team | None"] = relationship(back_populates="users")
|
||||||
|
oauth_accounts: Mapped[list["OAuthAccount"]] = relationship(back_populates="user")
|
||||||
|
tokens: Mapped[list["UserToken"]] = relationship(back_populates="user")
|
||||||
|
|
||||||
|
|
||||||
|
class UserToken(Base, TimestampMixin):
|
||||||
|
"""API tokens for a user (multiple allowed)."""
|
||||||
|
|
||||||
|
__tablename__ = "user_tokens"
|
||||||
|
|
||||||
|
user_id: Mapped[UUID] = mapped_column(ForeignKey("users.id"))
|
||||||
|
# Store hashed token value
|
||||||
|
token_hash: Mapped[str] = mapped_column(String, unique=True, index=True)
|
||||||
|
name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||||
|
expires_at: Mapped[datetime | None] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
|
user: Mapped["User"] = relationship(back_populates="tokens")
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthProvider(Base, TimestampMixin):
|
||||||
|
"""Configurable OAuth2 / OpenID Connect provider."""
|
||||||
|
|
||||||
|
__tablename__ = "oauth_providers"
|
||||||
|
|
||||||
|
slug: Mapped[str] = mapped_column(String, unique=True, index=True)
|
||||||
|
name: Mapped[str] = mapped_column(String)
|
||||||
|
client_id: Mapped[str] = mapped_column(String)
|
||||||
|
client_secret: Mapped[str] = mapped_column(String)
|
||||||
|
discovery_url: Mapped[str] = mapped_column(String, nullable=False)
|
||||||
|
scopes: Mapped[str] = mapped_column(String, default="openid email profile")
|
||||||
|
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||||
|
|
||||||
|
accounts: Mapped[list["OAuthAccount"]] = relationship(back_populates="provider")
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAccount(Base, TimestampMixin):
|
||||||
|
"""OAuth2 / OpenID Connect account linked to a user."""
|
||||||
|
|
||||||
|
__tablename__ = "oauth_accounts"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint("provider_id", "subject", name="uq_oauth_provider_subject"),
|
||||||
|
)
|
||||||
|
|
||||||
|
user_id: Mapped[UUID] = mapped_column(ForeignKey("users.id"))
|
||||||
|
provider_id: Mapped[UUID] = mapped_column(ForeignKey("oauth_providers.id"))
|
||||||
|
# OAuth `sub` / OpenID subject identifier
|
||||||
|
subject: Mapped[str] = mapped_column(String)
|
||||||
|
|
||||||
|
user: Mapped["User"] = relationship(back_populates="oauth_accounts")
|
||||||
|
provider: Mapped["OAuthProvider"] = relationship(back_populates="accounts")
|
||||||
122
docs_src/examples/authentication/routes.py
Normal file
122
docs_src/examples/authentication/routes.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import bcrypt
|
||||||
|
from fastapi import APIRouter, Form, HTTPException, Response, Security
|
||||||
|
|
||||||
|
from fastapi_toolsets.dependencies import PathDependency
|
||||||
|
|
||||||
|
from .crud import UserCrud, UserTokenCrud
|
||||||
|
from .db import SessionDep
|
||||||
|
from .models import OAuthProvider, User, UserToken
|
||||||
|
from .schemas import (
|
||||||
|
ApiTokenCreateRequest,
|
||||||
|
ApiTokenResponse,
|
||||||
|
RegisterRequest,
|
||||||
|
UserCreate,
|
||||||
|
UserResponse,
|
||||||
|
)
|
||||||
|
from .security import auth, cookie_auth, create_api_token
|
||||||
|
|
||||||
|
ProviderDep = PathDependency(
|
||||||
|
model=OAuthProvider,
|
||||||
|
field=OAuthProvider.slug,
|
||||||
|
session_dep=SessionDep,
|
||||||
|
param_name="slug",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def hash_password(password: str) -> str:
|
||||||
|
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(plain: str, hashed: str) -> bool:
|
||||||
|
return bcrypt.checkpw(plain.encode(), hashed.encode())
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/auth")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register", response_model=UserResponse, status_code=201)
|
||||||
|
async def register(body: RegisterRequest, session: SessionDep):
|
||||||
|
existing = await UserCrud.first(
|
||||||
|
session=session, filters=[User.username == body.username]
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
raise HTTPException(status_code=409, detail="Username already taken")
|
||||||
|
|
||||||
|
user = await UserCrud.create(
|
||||||
|
session=session,
|
||||||
|
obj=UserCreate(
|
||||||
|
username=body.username,
|
||||||
|
email=body.email,
|
||||||
|
hashed_password=hash_password(body.password),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/token", status_code=204)
|
||||||
|
async def login(
|
||||||
|
session: SessionDep,
|
||||||
|
response: Response,
|
||||||
|
username: Annotated[str, Form()],
|
||||||
|
password: Annotated[str, Form()],
|
||||||
|
):
|
||||||
|
user = await UserCrud.first(session=session, filters=[User.username == username])
|
||||||
|
|
||||||
|
if (
|
||||||
|
not user
|
||||||
|
or not user.hashed_password
|
||||||
|
or not verify_password(password, user.hashed_password)
|
||||||
|
):
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||||
|
|
||||||
|
if not user.is_active:
|
||||||
|
raise HTTPException(status_code=403, detail="Account disabled")
|
||||||
|
|
||||||
|
cookie_auth.set_cookie(response, str(user.id))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/logout", status_code=204)
|
||||||
|
async def logout(response: Response):
|
||||||
|
cookie_auth.delete_cookie(response)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me", response_model=UserResponse)
|
||||||
|
async def me(user: User = Security(auth)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/tokens", response_model=ApiTokenResponse, status_code=201)
|
||||||
|
async def create_token(
|
||||||
|
body: ApiTokenCreateRequest,
|
||||||
|
user: User = Security(auth),
|
||||||
|
):
|
||||||
|
raw, token_row = await create_api_token(
|
||||||
|
user.id, name=body.name, expires_at=body.expires_at
|
||||||
|
)
|
||||||
|
return ApiTokenResponse(
|
||||||
|
id=token_row.id,
|
||||||
|
name=token_row.name,
|
||||||
|
expires_at=token_row.expires_at,
|
||||||
|
created_at=token_row.created_at,
|
||||||
|
token=raw,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/tokens/{token_id}", status_code=204)
|
||||||
|
async def revoke_token(
|
||||||
|
session: SessionDep,
|
||||||
|
token_id: UUID,
|
||||||
|
user: User = Security(auth),
|
||||||
|
):
|
||||||
|
if not await UserTokenCrud.first(
|
||||||
|
session=session,
|
||||||
|
filters=[UserToken.id == token_id, UserToken.user_id == user.id],
|
||||||
|
):
|
||||||
|
raise HTTPException(status_code=404, detail="Token not found")
|
||||||
|
await UserTokenCrud.delete(
|
||||||
|
session=session,
|
||||||
|
filters=[UserToken.id == token_id, UserToken.user_id == user.id],
|
||||||
|
)
|
||||||
64
docs_src/examples/authentication/schemas.py
Normal file
64
docs_src/examples/authentication/schemas.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import EmailStr
|
||||||
|
|
||||||
|
from fastapi_toolsets.schemas import PydanticBase
|
||||||
|
|
||||||
|
|
||||||
|
class RegisterRequest(PydanticBase):
|
||||||
|
username: str
|
||||||
|
password: str
|
||||||
|
email: EmailStr | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class UserResponse(PydanticBase):
|
||||||
|
id: UUID
|
||||||
|
username: str
|
||||||
|
email: str | None
|
||||||
|
role: str
|
||||||
|
is_active: bool
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class ApiTokenCreateRequest(PydanticBase):
|
||||||
|
name: str | None = None
|
||||||
|
expires_at: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ApiTokenResponse(PydanticBase):
|
||||||
|
id: UUID
|
||||||
|
name: str | None
|
||||||
|
expires_at: datetime | None
|
||||||
|
created_at: datetime
|
||||||
|
# Only populated on creation
|
||||||
|
token: str | None = None
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthProviderResponse(PydanticBase):
|
||||||
|
slug: str
|
||||||
|
name: str
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class UserCreate(PydanticBase):
|
||||||
|
username: str
|
||||||
|
email: str | None = None
|
||||||
|
hashed_password: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class UserTokenCreate(PydanticBase):
|
||||||
|
user_id: UUID
|
||||||
|
token_hash: str
|
||||||
|
name: str | None = None
|
||||||
|
expires_at: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAccountCreate(PydanticBase):
|
||||||
|
user_id: UUID
|
||||||
|
provider_id: UUID
|
||||||
|
subject: str
|
||||||
100
docs_src/examples/authentication/security.py
Normal file
100
docs_src/examples/authentication/security.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
import hashlib
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
from fastapi_toolsets.exceptions import UnauthorizedError
|
||||||
|
from fastapi_toolsets.security import (
|
||||||
|
APIKeyHeaderAuth,
|
||||||
|
BearerTokenAuth,
|
||||||
|
CookieAuth,
|
||||||
|
MultiAuth,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .crud import UserCrud, UserTokenCrud
|
||||||
|
from .db import get_db_context
|
||||||
|
from .models import User, UserRole, UserToken
|
||||||
|
from .schemas import UserTokenCreate
|
||||||
|
|
||||||
|
SESSION_COOKIE = "session"
|
||||||
|
SECRET_KEY = "123456789"
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_token(token: str) -> str:
|
||||||
|
return hashlib.sha256(token.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
async def _verify_token(token: str, role: UserRole | None = None) -> User:
|
||||||
|
async with get_db_context() as db:
|
||||||
|
user_token = await UserTokenCrud.first(
|
||||||
|
session=db,
|
||||||
|
filters=[UserToken.token_hash == _hash_token(token)],
|
||||||
|
load_options=[selectinload(UserToken.user)],
|
||||||
|
)
|
||||||
|
|
||||||
|
if user_token is None or not user_token.user.is_active:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
|
||||||
|
if user_token.expires_at and user_token.expires_at < datetime.now(timezone.utc):
|
||||||
|
raise UnauthorizedError()
|
||||||
|
|
||||||
|
user = user_token.user
|
||||||
|
|
||||||
|
if role is not None and user.role != role:
|
||||||
|
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def _verify_cookie(user_id: str, role: UserRole | None = None) -> User:
|
||||||
|
async with get_db_context() as db:
|
||||||
|
user = await UserCrud.first(
|
||||||
|
session=db,
|
||||||
|
filters=[User.id == UUID(user_id)],
|
||||||
|
)
|
||||||
|
|
||||||
|
if not user or not user.is_active:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
|
||||||
|
if role is not None and user.role != role:
|
||||||
|
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
bearer_auth = BearerTokenAuth(
|
||||||
|
validator=_verify_token,
|
||||||
|
prefix="ctf_",
|
||||||
|
)
|
||||||
|
header_auth = APIKeyHeaderAuth(
|
||||||
|
name="X-API-Key",
|
||||||
|
validator=_verify_token,
|
||||||
|
)
|
||||||
|
cookie_auth = CookieAuth(
|
||||||
|
name=SESSION_COOKIE,
|
||||||
|
validator=_verify_cookie,
|
||||||
|
secret_key=SECRET_KEY,
|
||||||
|
)
|
||||||
|
auth = MultiAuth(bearer_auth, header_auth, cookie_auth)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_api_token(
|
||||||
|
user_id: UUID,
|
||||||
|
*,
|
||||||
|
name: str | None = None,
|
||||||
|
expires_at: datetime | None = None,
|
||||||
|
) -> tuple[str, UserToken]:
|
||||||
|
raw = bearer_auth.generate_token()
|
||||||
|
async with get_db_context() as db:
|
||||||
|
token_row = await UserTokenCrud.create(
|
||||||
|
session=db,
|
||||||
|
obj=UserTokenCreate(
|
||||||
|
user_id=user_id,
|
||||||
|
token_hash=_hash_token(raw),
|
||||||
|
name=name,
|
||||||
|
expires_at=expires_at,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return raw, token_row
|
||||||
0
docs_src/examples/pagination_search/__init__.py
Normal file
0
docs_src/examples/pagination_search/__init__.py
Normal file
9
docs_src/examples/pagination_search/app.py
Normal file
9
docs_src/examples/pagination_search/app.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||||
|
|
||||||
|
from .routes import router
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app=app)
|
||||||
|
app.include_router(router=router)
|
||||||
21
docs_src/examples/pagination_search/crud.py
Normal file
21
docs_src/examples/pagination_search/crud.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
from fastapi_toolsets.crud import CrudFactory
|
||||||
|
|
||||||
|
from .models import Article, Category
|
||||||
|
|
||||||
|
ArticleCrud = CrudFactory(
|
||||||
|
model=Article,
|
||||||
|
cursor_column=Article.created_at,
|
||||||
|
searchable_fields=[ # default fields for full-text search
|
||||||
|
Article.title,
|
||||||
|
Article.body,
|
||||||
|
(Article.category, Category.name),
|
||||||
|
],
|
||||||
|
facet_fields=[ # fields exposed as filter dropdowns
|
||||||
|
Article.status,
|
||||||
|
(Article.category, Category.name),
|
||||||
|
],
|
||||||
|
order_fields=[ # fields exposed for client-driven ordering
|
||||||
|
Article.title,
|
||||||
|
Article.created_at,
|
||||||
|
],
|
||||||
|
)
|
||||||
17
docs_src/examples/pagination_search/db.py
Normal file
17
docs_src/examples/pagination_search/db.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from fastapi_toolsets.db import create_db_context, create_db_dependency
|
||||||
|
|
||||||
|
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/postgres"
|
||||||
|
|
||||||
|
engine = create_async_engine(url=DATABASE_URL, future=True)
|
||||||
|
async_session_maker = async_sessionmaker(bind=engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
get_db = create_db_dependency(session_maker=async_session_maker)
|
||||||
|
get_db_context = create_db_context(session_maker=async_session_maker)
|
||||||
|
|
||||||
|
|
||||||
|
SessionDep = Annotated[AsyncSession, Depends(get_db)]
|
||||||
34
docs_src/examples/pagination_search/models.py
Normal file
34
docs_src/examples/pagination_search/models.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
import uuid
|
||||||
|
|
||||||
|
from sqlalchemy import Boolean, ForeignKey, String, Text
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from fastapi_toolsets.models import CreatedAtMixin
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Category(Base):
|
||||||
|
__tablename__ = "categories"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4)
|
||||||
|
name: Mapped[str] = mapped_column(String(64), unique=True)
|
||||||
|
|
||||||
|
articles: Mapped[list["Article"]] = relationship(back_populates="category")
|
||||||
|
|
||||||
|
|
||||||
|
class Article(Base, CreatedAtMixin):
|
||||||
|
__tablename__ = "articles"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4)
|
||||||
|
title: Mapped[str] = mapped_column(String(256))
|
||||||
|
body: Mapped[str] = mapped_column(Text)
|
||||||
|
status: Mapped[str] = mapped_column(String(32))
|
||||||
|
published: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||||
|
category_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
|
ForeignKey("categories.id"), nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
|
category: Mapped["Category | None"] = relationship(back_populates="articles")
|
||||||
79
docs_src/examples/pagination_search/routes.py
Normal file
79
docs_src/examples/pagination_search/routes.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
|
from fastapi_toolsets.schemas import (
|
||||||
|
CursorPaginatedResponse,
|
||||||
|
OffsetPaginatedResponse,
|
||||||
|
PaginatedResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .crud import ArticleCrud
|
||||||
|
from .db import SessionDep
|
||||||
|
from .models import Article
|
||||||
|
from .schemas import ArticleRead
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/articles")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/offset")
|
||||||
|
async def list_articles_offset(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[
|
||||||
|
dict,
|
||||||
|
Depends(
|
||||||
|
ArticleCrud.offset_paginate_params(
|
||||||
|
default_page_size=20,
|
||||||
|
max_page_size=100,
|
||||||
|
default_order_field=Article.created_at,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
) -> OffsetPaginatedResponse[ArticleRead]:
|
||||||
|
return await ArticleCrud.offset_paginate(
|
||||||
|
session=session,
|
||||||
|
**params,
|
||||||
|
schema=ArticleRead,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/cursor")
|
||||||
|
async def list_articles_cursor(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[
|
||||||
|
dict,
|
||||||
|
Depends(
|
||||||
|
ArticleCrud.cursor_paginate_params(
|
||||||
|
default_page_size=20,
|
||||||
|
max_page_size=100,
|
||||||
|
default_order_field=Article.created_at,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
) -> CursorPaginatedResponse[ArticleRead]:
|
||||||
|
return await ArticleCrud.cursor_paginate(
|
||||||
|
session=session,
|
||||||
|
**params,
|
||||||
|
schema=ArticleRead,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/")
|
||||||
|
async def list_articles(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[
|
||||||
|
dict,
|
||||||
|
Depends(
|
||||||
|
ArticleCrud.paginate_params(
|
||||||
|
default_page_size=20,
|
||||||
|
max_page_size=100,
|
||||||
|
default_order_field=Article.created_at,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
) -> PaginatedResponse[ArticleRead]:
|
||||||
|
return await ArticleCrud.paginate(
|
||||||
|
session,
|
||||||
|
**params,
|
||||||
|
schema=ArticleRead,
|
||||||
|
)
|
||||||
13
docs_src/examples/pagination_search/schemas.py
Normal file
13
docs_src/examples/pagination_search/schemas.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import datetime
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from fastapi_toolsets.schemas import PydanticBase
|
||||||
|
|
||||||
|
|
||||||
|
class ArticleRead(PydanticBase):
|
||||||
|
id: uuid.UUID
|
||||||
|
created_at: datetime.datetime
|
||||||
|
title: str
|
||||||
|
status: str
|
||||||
|
published: bool
|
||||||
|
category_id: uuid.UUID | None
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "fastapi-toolsets"
|
name = "fastapi-toolsets"
|
||||||
version = "0.1.0"
|
version = "3.0.1"
|
||||||
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
|
description = "Production-ready utilities for FastAPI applications"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
license-files = ["LICENSE"]
|
license-files = ["LICENSE"]
|
||||||
@@ -11,7 +11,7 @@ authors = [
|
|||||||
]
|
]
|
||||||
keywords = ["fastapi", "sqlalchemy", "postgresql"]
|
keywords = ["fastapi", "sqlalchemy", "postgresql"]
|
||||||
classifiers = [
|
classifiers = [
|
||||||
"Development Status :: 4 - Beta",
|
"Development Status :: 5 - Production/Stable",
|
||||||
"Framework :: AsyncIO",
|
"Framework :: AsyncIO",
|
||||||
"Framework :: FastAPI",
|
"Framework :: FastAPI",
|
||||||
"Framework :: Pydantic",
|
"Framework :: Pydantic",
|
||||||
@@ -24,18 +24,17 @@ classifiers = [
|
|||||||
"Programming Language :: Python :: 3.11",
|
"Programming Language :: Python :: 3.11",
|
||||||
"Programming Language :: Python :: 3.12",
|
"Programming Language :: Python :: 3.12",
|
||||||
"Programming Language :: Python :: 3.13",
|
"Programming Language :: Python :: 3.13",
|
||||||
|
"Programming Language :: Python :: 3.14",
|
||||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||||
"Topic :: Software Development :: Libraries",
|
"Topic :: Software Development :: Libraries",
|
||||||
"Topic :: Software Development",
|
"Topic :: Software Development",
|
||||||
"Typing :: Typed",
|
"Typing :: Typed",
|
||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"fastapi>=0.100.0",
|
|
||||||
"sqlalchemy[asyncio]>=2.0",
|
|
||||||
"asyncpg>=0.29.0",
|
"asyncpg>=0.29.0",
|
||||||
|
"fastapi>=0.100.0",
|
||||||
"pydantic>=2.0",
|
"pydantic>=2.0",
|
||||||
"typer>=0.9.0",
|
"sqlalchemy[asyncio]>=2.0",
|
||||||
"httpx>=0.25.0",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
@@ -45,23 +44,53 @@ Repository = "https://github.com/d3vyce/fastapi-toolsets"
|
|||||||
Issues = "https://github.com/d3vyce/fastapi-toolsets/issues"
|
Issues = "https://github.com/d3vyce/fastapi-toolsets/issues"
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
test = [
|
cli = [
|
||||||
"pytest>=8.0.0",
|
"typer>=0.9.0",
|
||||||
"pytest-anyio>=0.0.0",
|
|
||||||
"coverage>=7.0.0",
|
|
||||||
"pytest-cov>=4.0.0",
|
|
||||||
]
|
]
|
||||||
dev = [
|
metrics = [
|
||||||
"fastapi-toolsets[test]",
|
"prometheus_client>=0.20.0",
|
||||||
"ruff>=0.1.0",
|
]
|
||||||
"ty>=0.0.1a0",
|
pytest = [
|
||||||
|
"httpx>=0.25.0",
|
||||||
|
"pytest-xdist>=3.0.0",
|
||||||
|
"pytest>=8.0.0",
|
||||||
|
]
|
||||||
|
all = [
|
||||||
|
"fastapi-toolsets[cli,metrics,pytest]",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
fastapi-toolsets = "fastapi_toolsets.cli:app"
|
manager = "fastapi_toolsets.cli.app:cli"
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = [
|
||||||
|
{include-group = "tests"},
|
||||||
|
{include-group = "docs"},
|
||||||
|
{include-group = "docs-src"},
|
||||||
|
"fastapi-toolsets[all]",
|
||||||
|
"prek>=0.3.8",
|
||||||
|
"ruff>=0.1.0",
|
||||||
|
"ty>=0.0.1a0",
|
||||||
|
]
|
||||||
|
tests = [
|
||||||
|
"coverage>=7.0.0",
|
||||||
|
"httpx>=0.25.0",
|
||||||
|
"pytest-anyio>=0.0.0",
|
||||||
|
"pytest-cov>=4.0.0",
|
||||||
|
"pytest-xdist>=3.0.0",
|
||||||
|
"pytest>=8.0.0",
|
||||||
|
]
|
||||||
|
docs = [
|
||||||
|
"mike",
|
||||||
|
"mkdocstrings-python>=2.0.2",
|
||||||
|
"zensical>=0.0.30",
|
||||||
|
]
|
||||||
|
docs-src = [
|
||||||
|
"bcrypt>=4.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["uv_build>=0.9.26,<0.10.0"]
|
requires = ["uv_build>=0.10,<0.12.0"]
|
||||||
build-backend = "uv_build"
|
build-backend = "uv_build"
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
@@ -80,3 +109,6 @@ exclude_lines = [
|
|||||||
"if TYPE_CHECKING:",
|
"if TYPE_CHECKING:",
|
||||||
"raise NotImplementedError",
|
"raise NotImplementedError",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.uv.sources]
|
||||||
|
mike = { git = "https://github.com/squidfunk/mike.git", tag = "2.2.0+zensical-0.1.0" }
|
||||||
|
|||||||
@@ -21,4 +21,4 @@ Example usage:
|
|||||||
return Response(data={"user": user.username}, message="Success")
|
return Response(data={"user": user.username}, message="Success")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "3.0.1"
|
||||||
|
|||||||
9
src/fastapi_toolsets/_imports.py
Normal file
9
src/fastapi_toolsets/_imports.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
"""Optional dependency helpers."""
|
||||||
|
|
||||||
|
|
||||||
|
def require_extra(package: str, extra: str) -> None:
|
||||||
|
"""Raise *ImportError* with an actionable install instruction."""
|
||||||
|
raise ImportError(
|
||||||
|
f"'{package}' is required to use this module. "
|
||||||
|
f"Install it with: pip install fastapi-toolsets[{extra}]"
|
||||||
|
)
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
"""CLI for FastAPI projects."""
|
"""CLI for FastAPI projects."""
|
||||||
|
|
||||||
from .app import app, register_command
|
from .utils import async_command
|
||||||
|
|
||||||
__all__ = ["app", "register_command"]
|
__all__ = ["async_command"]
|
||||||
|
|||||||
@@ -1,97 +1,37 @@
|
|||||||
"""Main CLI application."""
|
"""Main CLI application."""
|
||||||
|
|
||||||
import importlib.util
|
try:
|
||||||
import sys
|
import typer
|
||||||
from pathlib import Path
|
except ImportError:
|
||||||
from typing import Annotated
|
from .._imports import require_extra
|
||||||
|
|
||||||
import typer
|
require_extra(package="typer", extra="cli")
|
||||||
|
|
||||||
from .commands import fixtures
|
from ..logger import configure_logging
|
||||||
|
from .config import get_custom_cli
|
||||||
|
from .pyproject import load_pyproject
|
||||||
|
|
||||||
app = typer.Typer(
|
# Use custom CLI if configured, otherwise create default one
|
||||||
name="fastapi-utils",
|
_custom_cli = get_custom_cli()
|
||||||
help="CLI utilities for FastAPI projects.",
|
|
||||||
no_args_is_help=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Register built-in commands
|
if _custom_cli is not None:
|
||||||
app.add_typer(fixtures.app, name="fixtures")
|
cli = _custom_cli
|
||||||
|
else:
|
||||||
|
cli = typer.Typer(
|
||||||
|
name="manager",
|
||||||
|
help="CLI utilities for FastAPI projects.",
|
||||||
|
no_args_is_help=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
_config = load_pyproject()
|
||||||
|
if _config.get("fixtures") and _config.get("db_context"):
|
||||||
|
from .commands.fixtures import fixture_cli
|
||||||
|
|
||||||
|
cli.add_typer(fixture_cli, name="fixtures")
|
||||||
|
|
||||||
|
|
||||||
def register_command(command: typer.Typer, name: str) -> None:
|
@cli.callback()
|
||||||
"""Register a custom command group.
|
def main(ctx: typer.Context) -> None:
|
||||||
|
|
||||||
Args:
|
|
||||||
command: Typer app for the command group
|
|
||||||
name: Name for the command group
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# In your project's cli.py:
|
|
||||||
import typer
|
|
||||||
from fastapi_toolsets.cli import app, register_command
|
|
||||||
|
|
||||||
my_commands = typer.Typer()
|
|
||||||
|
|
||||||
@my_commands.command()
|
|
||||||
def seed():
|
|
||||||
'''Seed the database.'''
|
|
||||||
...
|
|
||||||
|
|
||||||
register_command(my_commands, "db")
|
|
||||||
# Now available as: fastapi-utils db seed
|
|
||||||
"""
|
|
||||||
app.add_typer(command, name=name)
|
|
||||||
|
|
||||||
|
|
||||||
@app.callback()
|
|
||||||
def main(
|
|
||||||
ctx: typer.Context,
|
|
||||||
config: Annotated[
|
|
||||||
Path | None,
|
|
||||||
typer.Option(
|
|
||||||
"--config",
|
|
||||||
"-c",
|
|
||||||
help="Path to project config file (Python module with fixtures registry).",
|
|
||||||
envvar="FASTAPI_TOOLSETS_CONFIG",
|
|
||||||
),
|
|
||||||
] = None,
|
|
||||||
) -> None:
|
|
||||||
"""FastAPI utilities CLI."""
|
"""FastAPI utilities CLI."""
|
||||||
|
configure_logging()
|
||||||
ctx.ensure_object(dict)
|
ctx.ensure_object(dict)
|
||||||
|
|
||||||
if config:
|
|
||||||
ctx.obj["config_path"] = config
|
|
||||||
# Load the config module
|
|
||||||
config_module = _load_module_from_path(config)
|
|
||||||
ctx.obj["config_module"] = config_module
|
|
||||||
|
|
||||||
|
|
||||||
def _load_module_from_path(path: Path) -> object:
|
|
||||||
"""Load a Python module from a file path.
|
|
||||||
|
|
||||||
Handles both absolute and relative imports by adding the config's
|
|
||||||
parent directory to sys.path temporarily.
|
|
||||||
"""
|
|
||||||
path = path.resolve()
|
|
||||||
|
|
||||||
# Add the parent directory to sys.path to support relative imports
|
|
||||||
parent_dir = str(
|
|
||||||
path.parent.parent
|
|
||||||
) # Go up two levels (e.g., from app/cli_config.py to project root)
|
|
||||||
if parent_dir not in sys.path:
|
|
||||||
sys.path.insert(0, parent_dir)
|
|
||||||
|
|
||||||
# Also add immediate parent for direct module imports
|
|
||||||
immediate_parent = str(path.parent)
|
|
||||||
if immediate_parent not in sys.path:
|
|
||||||
sys.path.insert(0, immediate_parent)
|
|
||||||
|
|
||||||
spec = importlib.util.spec_from_file_location("config", path)
|
|
||||||
if spec is None or spec.loader is None:
|
|
||||||
raise typer.BadParameter(f"Cannot load module from {path}")
|
|
||||||
|
|
||||||
module = importlib.util.module_from_spec(spec)
|
|
||||||
sys.modules["config"] = module
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
return module
|
|
||||||
|
|||||||
@@ -1,138 +1,66 @@
|
|||||||
"""Fixture management commands."""
|
"""Fixture management commands."""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
from ...fixtures import Context, FixtureRegistry, LoadStrategy, load_fixtures_by_context
|
from ...fixtures import Context, LoadStrategy, load_fixtures_by_context
|
||||||
|
from ..config import get_db_context, get_fixtures_registry
|
||||||
|
from ..utils import async_command
|
||||||
|
|
||||||
app = typer.Typer(
|
fixture_cli = typer.Typer(
|
||||||
name="fixtures",
|
name="fixtures",
|
||||||
help="Manage database fixtures.",
|
help="Manage database fixtures.",
|
||||||
no_args_is_help=True,
|
no_args_is_help=True,
|
||||||
)
|
)
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
def _get_registry(ctx: typer.Context) -> FixtureRegistry:
|
@fixture_cli.command("list")
|
||||||
"""Get fixture registry from context."""
|
|
||||||
config = ctx.obj.get("config_module") if ctx.obj else None
|
|
||||||
if config is None:
|
|
||||||
raise typer.BadParameter(
|
|
||||||
"No config provided. Use --config to specify a config file with a 'fixtures' registry."
|
|
||||||
)
|
|
||||||
|
|
||||||
registry = getattr(config, "fixtures", None)
|
|
||||||
if registry is None:
|
|
||||||
raise typer.BadParameter(
|
|
||||||
"Config module must have a 'fixtures' attribute (FixtureRegistry instance)."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not isinstance(registry, FixtureRegistry):
|
|
||||||
raise typer.BadParameter(
|
|
||||||
f"'fixtures' must be a FixtureRegistry instance, got {type(registry).__name__}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return registry
|
|
||||||
|
|
||||||
|
|
||||||
def _get_db_context(ctx: typer.Context):
|
|
||||||
"""Get database context manager from config."""
|
|
||||||
config = ctx.obj.get("config_module") if ctx.obj else None
|
|
||||||
if config is None:
|
|
||||||
raise typer.BadParameter("No config provided.")
|
|
||||||
|
|
||||||
get_db_context = getattr(config, "get_db_context", None)
|
|
||||||
if get_db_context is None:
|
|
||||||
raise typer.BadParameter("Config module must have a 'get_db_context' function.")
|
|
||||||
|
|
||||||
return get_db_context
|
|
||||||
|
|
||||||
|
|
||||||
@app.command("list")
|
|
||||||
def list_fixtures(
|
def list_fixtures(
|
||||||
ctx: typer.Context,
|
ctx: typer.Context,
|
||||||
context: Annotated[
|
context: Annotated[
|
||||||
str | None,
|
Context | None,
|
||||||
typer.Option(
|
typer.Option(
|
||||||
"--context",
|
"--context",
|
||||||
"-c",
|
"-c",
|
||||||
help="Filter by context (base, production, development, testing).",
|
help="Filter by context.",
|
||||||
),
|
),
|
||||||
] = None,
|
] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""List all registered fixtures."""
|
"""List all registered fixtures."""
|
||||||
registry = _get_registry(ctx)
|
registry = get_fixtures_registry()
|
||||||
|
fixtures = registry.get_by_context(context.value) if context else registry.get_all()
|
||||||
if context:
|
|
||||||
fixtures = registry.get_by_context(context)
|
|
||||||
else:
|
|
||||||
fixtures = registry.get_all()
|
|
||||||
|
|
||||||
if not fixtures:
|
if not fixtures:
|
||||||
typer.echo("No fixtures found.")
|
print("No fixtures found.")
|
||||||
return
|
return
|
||||||
|
|
||||||
typer.echo(f"\n{'Name':<30} {'Contexts':<30} {'Dependencies'}")
|
table = Table("Name", "Contexts", "Dependencies")
|
||||||
typer.echo("-" * 80)
|
|
||||||
|
|
||||||
for fixture in fixtures:
|
for fixture in fixtures:
|
||||||
contexts = ", ".join(fixture.contexts)
|
contexts = ", ".join(fixture.contexts)
|
||||||
deps = ", ".join(fixture.depends_on) if fixture.depends_on else "-"
|
deps = ", ".join(fixture.depends_on) if fixture.depends_on else "-"
|
||||||
typer.echo(f"{fixture.name:<30} {contexts:<30} {deps}")
|
table.add_row(fixture.name, contexts, deps)
|
||||||
|
|
||||||
typer.echo(f"\nTotal: {len(fixtures)} fixture(s)")
|
console.print(table)
|
||||||
|
print(f"\nTotal: {len(fixtures)} fixture(s)")
|
||||||
|
|
||||||
|
|
||||||
@app.command("graph")
|
@fixture_cli.command("load")
|
||||||
def show_graph(
|
@async_command
|
||||||
ctx: typer.Context,
|
async def load(
|
||||||
fixture_name: Annotated[
|
|
||||||
str | None,
|
|
||||||
typer.Argument(help="Show dependencies for a specific fixture."),
|
|
||||||
] = None,
|
|
||||||
) -> None:
|
|
||||||
"""Show fixture dependency graph."""
|
|
||||||
registry = _get_registry(ctx)
|
|
||||||
|
|
||||||
if fixture_name:
|
|
||||||
try:
|
|
||||||
order = registry.resolve_dependencies(fixture_name)
|
|
||||||
typer.echo(f"\nDependency chain for '{fixture_name}':\n")
|
|
||||||
for i, name in enumerate(order):
|
|
||||||
indent = " " * i
|
|
||||||
arrow = "└─> " if i > 0 else ""
|
|
||||||
typer.echo(f"{indent}{arrow}{name}")
|
|
||||||
except KeyError:
|
|
||||||
typer.echo(f"Fixture '{fixture_name}' not found.", err=True)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
else:
|
|
||||||
# Show full graph
|
|
||||||
fixtures = registry.get_all()
|
|
||||||
|
|
||||||
typer.echo("\nFixture Dependency Graph:\n")
|
|
||||||
for fixture in fixtures:
|
|
||||||
deps = (
|
|
||||||
f" -> [{', '.join(fixture.depends_on)}]" if fixture.depends_on else ""
|
|
||||||
)
|
|
||||||
typer.echo(f" {fixture.name}{deps}")
|
|
||||||
|
|
||||||
|
|
||||||
@app.command("load")
|
|
||||||
def load(
|
|
||||||
ctx: typer.Context,
|
ctx: typer.Context,
|
||||||
contexts: Annotated[
|
contexts: Annotated[
|
||||||
list[str] | None,
|
list[Context] | None,
|
||||||
typer.Argument(
|
typer.Argument(help="Contexts to load."),
|
||||||
help="Contexts to load (base, production, development, testing)."
|
|
||||||
),
|
|
||||||
] = None,
|
] = None,
|
||||||
strategy: Annotated[
|
strategy: Annotated[
|
||||||
str,
|
LoadStrategy,
|
||||||
typer.Option(
|
typer.Option("--strategy", "-s", help="Load strategy."),
|
||||||
"--strategy", "-s", help="Load strategy: merge, insert, skip_existing."
|
] = LoadStrategy.MERGE,
|
||||||
),
|
|
||||||
] = "merge",
|
|
||||||
dry_run: Annotated[
|
dry_run: Annotated[
|
||||||
bool,
|
bool,
|
||||||
typer.Option(
|
typer.Option(
|
||||||
@@ -141,85 +69,32 @@ def load(
|
|||||||
] = False,
|
] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Load fixtures into the database."""
|
"""Load fixtures into the database."""
|
||||||
registry = _get_registry(ctx)
|
registry = get_fixtures_registry()
|
||||||
get_db_context = _get_db_context(ctx)
|
db_context = get_db_context()
|
||||||
|
|
||||||
# Parse contexts
|
context_list = list(contexts) if contexts else [Context.BASE]
|
||||||
if contexts:
|
|
||||||
context_list = contexts
|
|
||||||
else:
|
|
||||||
context_list = [Context.BASE]
|
|
||||||
|
|
||||||
# Parse strategy
|
|
||||||
try:
|
|
||||||
load_strategy = LoadStrategy(strategy)
|
|
||||||
except ValueError:
|
|
||||||
typer.echo(
|
|
||||||
f"Invalid strategy: {strategy}. Use: merge, insert, skip_existing", err=True
|
|
||||||
)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
# Resolve what will be loaded
|
|
||||||
ordered = registry.resolve_context_dependencies(*context_list)
|
ordered = registry.resolve_context_dependencies(*context_list)
|
||||||
|
|
||||||
if not ordered:
|
if not ordered:
|
||||||
typer.echo("No fixtures to load for the specified context(s).")
|
print("No fixtures to load for the specified context(s).")
|
||||||
return
|
return
|
||||||
|
|
||||||
typer.echo(f"\nFixtures to load ({load_strategy.value} strategy):")
|
print(f"\nFixtures to load ({strategy.value} strategy):")
|
||||||
for name in ordered:
|
for name in ordered:
|
||||||
fixture = registry.get(name)
|
fixture = registry.get(name)
|
||||||
instances = list(fixture.func())
|
instances = list(fixture.func())
|
||||||
model_name = type(instances[0]).__name__ if instances else "?"
|
model_name = type(instances[0]).__name__ if instances else "?"
|
||||||
typer.echo(f" - {name}: {len(instances)} {model_name}(s)")
|
print(f" - {name}: {len(instances)} {model_name}(s)")
|
||||||
|
|
||||||
if dry_run:
|
if dry_run:
|
||||||
typer.echo("\n[Dry run - no changes made]")
|
print("\n[Dry run - no changes made]")
|
||||||
return
|
return
|
||||||
|
|
||||||
typer.echo("\nLoading...")
|
async with db_context() as session:
|
||||||
|
result = await load_fixtures_by_context(
|
||||||
async def do_load():
|
session, registry, *context_list, strategy=strategy
|
||||||
async with get_db_context() as session:
|
)
|
||||||
result = await load_fixtures_by_context(
|
|
||||||
session, registry, *context_list, strategy=load_strategy
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
result = asyncio.run(do_load())
|
|
||||||
|
|
||||||
total = sum(len(items) for items in result.values())
|
total = sum(len(items) for items in result.values())
|
||||||
typer.echo(f"\nLoaded {total} record(s) successfully.")
|
print(f"\nLoaded {total} record(s) successfully.")
|
||||||
|
|
||||||
|
|
||||||
@app.command("show")
|
|
||||||
def show_fixture(
|
|
||||||
ctx: typer.Context,
|
|
||||||
name: Annotated[str, typer.Argument(help="Fixture name to show.")],
|
|
||||||
) -> None:
|
|
||||||
"""Show details of a specific fixture."""
|
|
||||||
registry = _get_registry(ctx)
|
|
||||||
|
|
||||||
try:
|
|
||||||
fixture = registry.get(name)
|
|
||||||
except KeyError:
|
|
||||||
typer.echo(f"Fixture '{name}' not found.", err=True)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
typer.echo(f"\nFixture: {fixture.name}")
|
|
||||||
typer.echo(f"Contexts: {', '.join(fixture.contexts)}")
|
|
||||||
typer.echo(
|
|
||||||
f"Dependencies: {', '.join(fixture.depends_on) if fixture.depends_on else 'None'}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Show instances
|
|
||||||
instances = list(fixture.func())
|
|
||||||
if instances:
|
|
||||||
model_name = type(instances[0]).__name__
|
|
||||||
typer.echo(f"\nInstances ({len(instances)} {model_name}):")
|
|
||||||
for instance in instances[:10]: # Limit to 10
|
|
||||||
typer.echo(f" - {instance!r}")
|
|
||||||
if len(instances) > 10:
|
|
||||||
typer.echo(f" ... and {len(instances) - 10} more")
|
|
||||||
else:
|
|
||||||
typer.echo("\nNo instances (empty fixture)")
|
|
||||||
|
|||||||
125
src/fastapi_toolsets/cli/config.py
Normal file
125
src/fastapi_toolsets/cli/config.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""CLI configuration and dynamic imports."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
from typing import TYPE_CHECKING, Any, Literal, overload
|
||||||
|
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from .pyproject import find_pyproject, load_pyproject
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..fixtures import FixtureRegistry
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_project_in_path():
|
||||||
|
"""Add project root to sys.path if not installed in editable mode."""
|
||||||
|
pyproject = find_pyproject()
|
||||||
|
if pyproject:
|
||||||
|
project_root = str(pyproject.parent)
|
||||||
|
if project_root not in sys.path:
|
||||||
|
sys.path.insert(0, project_root)
|
||||||
|
|
||||||
|
|
||||||
|
def import_from_string(import_path: str) -> Any:
|
||||||
|
"""Import an object from a dotted string path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
import_path: Import path in ``"module.submodule:attribute"`` format
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The imported attribute
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
typer.BadParameter: If the import path is invalid or import fails
|
||||||
|
"""
|
||||||
|
if ":" not in import_path:
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"Invalid import path '{import_path}'. Expected format: 'module:attribute'"
|
||||||
|
)
|
||||||
|
|
||||||
|
module_path, attr_name = import_path.rsplit(":", 1)
|
||||||
|
|
||||||
|
_ensure_project_in_path()
|
||||||
|
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(module_path)
|
||||||
|
except ImportError as e:
|
||||||
|
raise typer.BadParameter(f"Cannot import module '{module_path}': {e}")
|
||||||
|
|
||||||
|
if not hasattr(module, attr_name):
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"Module '{module_path}' has no attribute '{attr_name}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
return getattr(module, attr_name)
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_config_value(key: str, required: Literal[True]) -> Any: ... # pragma: no cover
|
||||||
|
@overload
|
||||||
|
def get_config_value(
|
||||||
|
key: str, required: bool = False
|
||||||
|
) -> Any | None: ... # pragma: no cover
|
||||||
|
def get_config_value(key: str, required: bool = False) -> Any | None:
|
||||||
|
"""Get a configuration value from pyproject.toml.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The configuration key in [tool.fastapi-toolsets].
|
||||||
|
required: If True, raises an error when the key is missing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The configuration value, or None if not found and not required.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
typer.BadParameter: If required=True and the key is missing.
|
||||||
|
"""
|
||||||
|
config = load_pyproject()
|
||||||
|
value = config.get(key)
|
||||||
|
|
||||||
|
if required and value is None:
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"No '{key}' configured. "
|
||||||
|
f"Add '{key}' to [tool.fastapi-toolsets] in pyproject.toml."
|
||||||
|
)
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def get_fixtures_registry() -> FixtureRegistry:
|
||||||
|
"""Import and return the fixtures registry from config."""
|
||||||
|
from ..fixtures import FixtureRegistry
|
||||||
|
|
||||||
|
import_path = get_config_value("fixtures", required=True)
|
||||||
|
registry = import_from_string(import_path)
|
||||||
|
|
||||||
|
if not isinstance(registry, FixtureRegistry):
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"'fixtures' must be a FixtureRegistry instance, got {type(registry).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return registry
|
||||||
|
|
||||||
|
|
||||||
|
def get_db_context() -> Any:
|
||||||
|
"""Import and return the db_context function from config."""
|
||||||
|
import_path = get_config_value("db_context", required=True)
|
||||||
|
return import_from_string(import_path)
|
||||||
|
|
||||||
|
|
||||||
|
def get_custom_cli() -> typer.Typer | None:
|
||||||
|
"""Import and return the custom CLI Typer instance from config."""
|
||||||
|
import_path = get_config_value("custom_cli")
|
||||||
|
if not import_path:
|
||||||
|
return None
|
||||||
|
|
||||||
|
custom = import_from_string(import_path)
|
||||||
|
|
||||||
|
if not isinstance(custom, typer.Typer):
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"'custom_cli' must be a Typer instance, got {type(custom).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return custom
|
||||||
49
src/fastapi_toolsets/cli/pyproject.py
Normal file
49
src/fastapi_toolsets/cli/pyproject.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
"""Pyproject.toml discovery and loading."""
|
||||||
|
|
||||||
|
import tomllib
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
TOOL_NAME = "fastapi-toolsets"
|
||||||
|
|
||||||
|
|
||||||
|
def find_pyproject(start_path: Path | None = None) -> Path | None:
|
||||||
|
"""Find pyproject.toml by walking up the directory tree.
|
||||||
|
|
||||||
|
Similar to how pytest, black, and ruff discover their config files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_path: Directory to start searching from. Defaults to cwd.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to pyproject.toml, or None if not found.
|
||||||
|
"""
|
||||||
|
path = (start_path or Path.cwd()).resolve()
|
||||||
|
|
||||||
|
for directory in [path, *path.parents]:
|
||||||
|
pyproject = directory / "pyproject.toml"
|
||||||
|
if pyproject.is_file():
|
||||||
|
return pyproject
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def load_pyproject(path: Path | None = None) -> dict:
|
||||||
|
"""Load tool configuration from pyproject.toml.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Explicit path to pyproject.toml. If None, searches up from cwd.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The [tool.fastapi-toolsets] section as a dict, or empty dict if not found.
|
||||||
|
"""
|
||||||
|
pyproject_path = path or find_pyproject()
|
||||||
|
|
||||||
|
if not pyproject_path:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(pyproject_path, "rb") as f:
|
||||||
|
data = tomllib.load(f)
|
||||||
|
return data.get("tool", {}).get(TOOL_NAME, {})
|
||||||
|
except (OSError, tomllib.TOMLDecodeError):
|
||||||
|
return {}
|
||||||
29
src/fastapi_toolsets/cli/utils.py
Normal file
29
src/fastapi_toolsets/cli/utils.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
"""CLI utility functions."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import functools
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
|
from typing import Any, ParamSpec, TypeVar
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def async_command(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, T]:
|
||||||
|
"""Decorator to run an async function as a sync CLI command.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
@fixture_cli.command("load")
|
||||||
|
@async_command
|
||||||
|
async def load(ctx: typer.Context) -> None:
|
||||||
|
async with get_db_context() as session:
|
||||||
|
await load_fixtures(session, registry)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||||
|
return asyncio.run(func(*args, **kwargs))
|
||||||
|
|
||||||
|
return wrapper
|
||||||
@@ -1,378 +0,0 @@
|
|||||||
"""Generic async CRUD operations for SQLAlchemy models."""
|
|
||||||
|
|
||||||
from collections.abc import Sequence
|
|
||||||
from typing import Any, ClassVar, Generic, Self, TypeVar, cast
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from sqlalchemy import and_, func, select
|
|
||||||
from sqlalchemy import delete as sql_delete
|
|
||||||
from sqlalchemy.dialects.postgresql import insert
|
|
||||||
from sqlalchemy.exc import NoResultFound
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
|
||||||
from sqlalchemy.sql.roles import WhereHavingRole
|
|
||||||
|
|
||||||
from .db import get_transaction
|
|
||||||
from .exceptions import NotFoundError
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"AsyncCrud",
|
|
||||||
"CrudFactory",
|
|
||||||
]
|
|
||||||
|
|
||||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncCrud(Generic[ModelType]):
|
|
||||||
"""Generic async CRUD operations for SQLAlchemy models.
|
|
||||||
|
|
||||||
Subclass this and set the `model` class variable, or use `CrudFactory`.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
class UserCrud(AsyncCrud[User]):
|
|
||||||
model = User
|
|
||||||
|
|
||||||
# Or use the factory:
|
|
||||||
UserCrud = CrudFactory(User)
|
|
||||||
|
|
||||||
# Then use it:
|
|
||||||
user = await UserCrud.get(session, [User.id == 1])
|
|
||||||
users = await UserCrud.get_multi(session, limit=10)
|
|
||||||
"""
|
|
||||||
|
|
||||||
model: ClassVar[type[DeclarativeBase]]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def create(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
obj: BaseModel,
|
|
||||||
) -> ModelType:
|
|
||||||
"""Create a new record in the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
obj: Pydantic model with data to create
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Created model instance
|
|
||||||
"""
|
|
||||||
async with get_transaction(session):
|
|
||||||
db_model = cls.model(**obj.model_dump())
|
|
||||||
session.add(db_model)
|
|
||||||
await session.refresh(db_model)
|
|
||||||
return cast(ModelType, db_model)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
filters: list[Any],
|
|
||||||
*,
|
|
||||||
with_for_update: bool = False,
|
|
||||||
load_options: list[Any] | None = None,
|
|
||||||
) -> ModelType:
|
|
||||||
"""Get exactly one record. Raises NotFoundError if not found.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
with_for_update: Lock the row for update
|
|
||||||
load_options: SQLAlchemy loader options (e.g., selectinload)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Model instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
NotFoundError: If no record found
|
|
||||||
MultipleResultsFound: If more than one record found
|
|
||||||
"""
|
|
||||||
q = select(cls.model).where(and_(*filters))
|
|
||||||
if load_options:
|
|
||||||
q = q.options(*load_options)
|
|
||||||
if with_for_update:
|
|
||||||
q = q.with_for_update()
|
|
||||||
result = await session.execute(q)
|
|
||||||
item = result.unique().scalar_one_or_none()
|
|
||||||
if not item:
|
|
||||||
raise NotFoundError()
|
|
||||||
return cast(ModelType, item)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def first(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
filters: list[Any] | None = None,
|
|
||||||
*,
|
|
||||||
load_options: list[Any] | None = None,
|
|
||||||
) -> ModelType | None:
|
|
||||||
"""Get the first matching record, or None.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
load_options: SQLAlchemy loader options
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Model instance or None
|
|
||||||
"""
|
|
||||||
q = select(cls.model)
|
|
||||||
if filters:
|
|
||||||
q = q.where(and_(*filters))
|
|
||||||
if load_options:
|
|
||||||
q = q.options(*load_options)
|
|
||||||
result = await session.execute(q)
|
|
||||||
return cast(ModelType | None, result.unique().scalars().first())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get_multi(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
*,
|
|
||||||
filters: list[Any] | None = None,
|
|
||||||
load_options: list[Any] | None = None,
|
|
||||||
order_by: Any | None = None,
|
|
||||||
limit: int | None = None,
|
|
||||||
offset: int | None = None,
|
|
||||||
) -> Sequence[ModelType]:
|
|
||||||
"""Get multiple records from the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
load_options: SQLAlchemy loader options
|
|
||||||
order_by: Column or list of columns to order by
|
|
||||||
limit: Max number of rows to return
|
|
||||||
offset: Rows to skip
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of model instances
|
|
||||||
"""
|
|
||||||
q = select(cls.model)
|
|
||||||
if filters:
|
|
||||||
q = q.where(and_(*filters))
|
|
||||||
if load_options:
|
|
||||||
q = q.options(*load_options)
|
|
||||||
if order_by is not None:
|
|
||||||
q = q.order_by(order_by)
|
|
||||||
if offset is not None:
|
|
||||||
q = q.offset(offset)
|
|
||||||
if limit is not None:
|
|
||||||
q = q.limit(limit)
|
|
||||||
result = await session.execute(q)
|
|
||||||
return cast(Sequence[ModelType], result.unique().scalars().all())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def update(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
obj: BaseModel,
|
|
||||||
filters: list[Any],
|
|
||||||
*,
|
|
||||||
exclude_unset: bool = True,
|
|
||||||
exclude_none: bool = False,
|
|
||||||
) -> ModelType:
|
|
||||||
"""Update a record in the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
obj: Pydantic model with update data
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
exclude_unset: Exclude fields not explicitly set in the schema
|
|
||||||
exclude_none: Exclude fields with None value
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated model instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
NotFoundError: If no record found
|
|
||||||
"""
|
|
||||||
async with get_transaction(session):
|
|
||||||
db_model = await cls.get(session=session, filters=filters)
|
|
||||||
values = obj.model_dump(
|
|
||||||
exclude_unset=exclude_unset, exclude_none=exclude_none
|
|
||||||
)
|
|
||||||
for key, value in values.items():
|
|
||||||
setattr(db_model, key, value)
|
|
||||||
await session.refresh(db_model)
|
|
||||||
return db_model
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def upsert(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
obj: BaseModel,
|
|
||||||
index_elements: list[str],
|
|
||||||
*,
|
|
||||||
set_: BaseModel | None = None,
|
|
||||||
where: WhereHavingRole | None = None,
|
|
||||||
) -> ModelType | None:
|
|
||||||
"""Create or update a record (PostgreSQL only).
|
|
||||||
|
|
||||||
Uses INSERT ... ON CONFLICT for atomic upsert.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
obj: Pydantic model with data
|
|
||||||
index_elements: Columns for ON CONFLICT (unique constraint)
|
|
||||||
set_: Pydantic model for ON CONFLICT DO UPDATE SET
|
|
||||||
where: WHERE clause for ON CONFLICT DO UPDATE
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Model instance
|
|
||||||
"""
|
|
||||||
async with get_transaction(session):
|
|
||||||
values = obj.model_dump(exclude_unset=True)
|
|
||||||
q = insert(cls.model).values(**values)
|
|
||||||
if set_:
|
|
||||||
q = q.on_conflict_do_update(
|
|
||||||
index_elements=index_elements,
|
|
||||||
set_=set_.model_dump(exclude_unset=True),
|
|
||||||
where=where,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
q = q.on_conflict_do_nothing(index_elements=index_elements)
|
|
||||||
q = q.returning(cls.model)
|
|
||||||
result = await session.execute(q)
|
|
||||||
try:
|
|
||||||
db_model = result.unique().scalar_one()
|
|
||||||
except NoResultFound:
|
|
||||||
db_model = await cls.first(
|
|
||||||
session=session,
|
|
||||||
filters=[getattr(cls.model, k) == v for k, v in values.items()],
|
|
||||||
)
|
|
||||||
return cast(ModelType | None, db_model)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def delete(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
filters: list[Any],
|
|
||||||
) -> bool:
|
|
||||||
"""Delete records from the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if deletion was executed
|
|
||||||
"""
|
|
||||||
async with get_transaction(session):
|
|
||||||
q = sql_delete(cls.model).where(and_(*filters))
|
|
||||||
await session.execute(q)
|
|
||||||
return True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def count(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
filters: list[Any] | None = None,
|
|
||||||
) -> int:
|
|
||||||
"""Count records matching the filters.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of matching records
|
|
||||||
"""
|
|
||||||
q = select(func.count()).select_from(cls.model)
|
|
||||||
if filters:
|
|
||||||
q = q.where(and_(*filters))
|
|
||||||
result = await session.execute(q)
|
|
||||||
return result.scalar_one()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def exists(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
filters: list[Any],
|
|
||||||
) -> bool:
|
|
||||||
"""Check if a record exists.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if at least one record matches
|
|
||||||
"""
|
|
||||||
q = select(cls.model).where(and_(*filters)).exists().select()
|
|
||||||
result = await session.execute(q)
|
|
||||||
return bool(result.scalar())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def paginate(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
*,
|
|
||||||
filters: list[Any] | None = None,
|
|
||||||
load_options: list[Any] | None = None,
|
|
||||||
order_by: Any | None = None,
|
|
||||||
page: int = 1,
|
|
||||||
items_per_page: int = 20,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Get paginated results with metadata.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
load_options: SQLAlchemy loader options
|
|
||||||
order_by: Column or list of columns to order by
|
|
||||||
page: Page number (1-indexed)
|
|
||||||
items_per_page: Number of items per page
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict with 'data' and 'pagination' keys
|
|
||||||
"""
|
|
||||||
filters = filters or []
|
|
||||||
offset = (page - 1) * items_per_page
|
|
||||||
|
|
||||||
items = await cls.get_multi(
|
|
||||||
session,
|
|
||||||
filters=filters,
|
|
||||||
load_options=load_options,
|
|
||||||
order_by=order_by,
|
|
||||||
limit=items_per_page,
|
|
||||||
offset=offset,
|
|
||||||
)
|
|
||||||
|
|
||||||
total_count = await cls.count(session, filters=filters)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"data": items,
|
|
||||||
"pagination": {
|
|
||||||
"total_count": total_count,
|
|
||||||
"items_per_page": items_per_page,
|
|
||||||
"page": page,
|
|
||||||
"has_more": page * items_per_page < total_count,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def CrudFactory(
|
|
||||||
model: type[ModelType],
|
|
||||||
) -> type[AsyncCrud[ModelType]]:
|
|
||||||
"""Create a CRUD class for a specific model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: SQLAlchemy model class
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AsyncCrud subclass bound to the model
|
|
||||||
|
|
||||||
Example:
|
|
||||||
from fastapi_toolsets.crud import CrudFactory
|
|
||||||
from myapp.models import User, Post
|
|
||||||
|
|
||||||
UserCrud = CrudFactory(User)
|
|
||||||
PostCrud = CrudFactory(Post)
|
|
||||||
|
|
||||||
# Usage
|
|
||||||
user = await UserCrud.get(session, [User.id == 1])
|
|
||||||
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
|
|
||||||
"""
|
|
||||||
cls = type(f"Async{model.__name__}Crud", (AsyncCrud,), {"model": model})
|
|
||||||
return cast(type[AsyncCrud[ModelType]], cls)
|
|
||||||
35
src/fastapi_toolsets/crud/__init__.py
Normal file
35
src/fastapi_toolsets/crud/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""Generic async CRUD operations for SQLAlchemy models."""
|
||||||
|
|
||||||
|
from ..exceptions import (
|
||||||
|
InvalidFacetFilterError,
|
||||||
|
InvalidSearchColumnError,
|
||||||
|
NoSearchableFieldsError,
|
||||||
|
UnsupportedFacetTypeError,
|
||||||
|
)
|
||||||
|
from ..schemas import PaginationType
|
||||||
|
from ..types import (
|
||||||
|
FacetFieldType,
|
||||||
|
JoinType,
|
||||||
|
M2MFieldType,
|
||||||
|
OrderByClause,
|
||||||
|
SearchFieldType,
|
||||||
|
)
|
||||||
|
from .factory import AsyncCrud, CrudFactory
|
||||||
|
from .search import SearchConfig, get_searchable_fields
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AsyncCrud",
|
||||||
|
"CrudFactory",
|
||||||
|
"FacetFieldType",
|
||||||
|
"get_searchable_fields",
|
||||||
|
"InvalidFacetFilterError",
|
||||||
|
"InvalidSearchColumnError",
|
||||||
|
"JoinType",
|
||||||
|
"M2MFieldType",
|
||||||
|
"NoSearchableFieldsError",
|
||||||
|
"OrderByClause",
|
||||||
|
"PaginationType",
|
||||||
|
"SearchConfig",
|
||||||
|
"SearchFieldType",
|
||||||
|
"UnsupportedFacetTypeError",
|
||||||
|
]
|
||||||
1796
src/fastapi_toolsets/crud/factory.py
Normal file
1796
src/fastapi_toolsets/crud/factory.py
Normal file
File diff suppressed because it is too large
Load Diff
358
src/fastapi_toolsets/crud/search.py
Normal file
358
src/fastapi_toolsets/crud/search.py
Normal file
@@ -0,0 +1,358 @@
|
|||||||
|
"""Search utilities for AsyncCrud."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import functools
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass, replace
|
||||||
|
from typing import TYPE_CHECKING, Any, Literal
|
||||||
|
|
||||||
|
from sqlalchemy import String, and_, func, or_, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||||
|
from sqlalchemy.types import (
|
||||||
|
ARRAY,
|
||||||
|
Boolean,
|
||||||
|
Date,
|
||||||
|
DateTime,
|
||||||
|
Enum,
|
||||||
|
Integer,
|
||||||
|
Numeric,
|
||||||
|
Time,
|
||||||
|
Uuid,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..exceptions import (
|
||||||
|
InvalidFacetFilterError,
|
||||||
|
InvalidSearchColumnError,
|
||||||
|
NoSearchableFieldsError,
|
||||||
|
UnsupportedFacetTypeError,
|
||||||
|
)
|
||||||
|
from ..types import FacetFieldType, SearchFieldType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sqlalchemy.sql.elements import ColumnElement
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SearchConfig:
|
||||||
|
"""Advanced search configuration.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
query: The search string
|
||||||
|
fields: Fields to search (columns or tuples for relationships)
|
||||||
|
case_sensitive: Case-sensitive search (default: False)
|
||||||
|
match_mode: "any" (OR) or "all" (AND) to combine fields
|
||||||
|
"""
|
||||||
|
|
||||||
|
query: str
|
||||||
|
fields: Sequence[SearchFieldType] | None = None
|
||||||
|
case_sensitive: bool = False
|
||||||
|
match_mode: Literal["any", "all"] = "any"
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=128)
|
||||||
|
def get_searchable_fields(
|
||||||
|
model: type[DeclarativeBase],
|
||||||
|
*,
|
||||||
|
include_relationships: bool = True,
|
||||||
|
max_depth: int = 1,
|
||||||
|
) -> list[SearchFieldType]:
|
||||||
|
"""Auto-detect String fields on a model and its relationships.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
include_relationships: Include fields from many-to-one/one-to-one relationships
|
||||||
|
max_depth: Max depth for relationship traversal (default: 1)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of columns and tuples (relationship, column)
|
||||||
|
"""
|
||||||
|
fields: list[SearchFieldType] = []
|
||||||
|
mapper = model.__mapper__
|
||||||
|
|
||||||
|
# Direct String columns
|
||||||
|
for col in mapper.columns:
|
||||||
|
if isinstance(col.type, String):
|
||||||
|
fields.append(getattr(model, col.key))
|
||||||
|
|
||||||
|
# Relationships (one-to-one, many-to-one only)
|
||||||
|
if include_relationships and max_depth > 0:
|
||||||
|
for rel_name, rel_prop in mapper.relationships.items():
|
||||||
|
if rel_prop.uselist: # Skip collections (one-to-many, many-to-many)
|
||||||
|
continue
|
||||||
|
|
||||||
|
rel_attr = getattr(model, rel_name)
|
||||||
|
related_model = rel_prop.mapper.class_
|
||||||
|
|
||||||
|
for col in related_model.__mapper__.columns:
|
||||||
|
if isinstance(col.type, String):
|
||||||
|
fields.append((rel_attr, getattr(related_model, col.key)))
|
||||||
|
|
||||||
|
return fields
|
||||||
|
|
||||||
|
|
||||||
|
def build_search_filters(
|
||||||
|
model: type[DeclarativeBase],
|
||||||
|
search: str | SearchConfig,
|
||||||
|
search_fields: Sequence[SearchFieldType] | None = None,
|
||||||
|
default_fields: Sequence[SearchFieldType] | None = None,
|
||||||
|
search_column: str | None = None,
|
||||||
|
) -> tuple[list["ColumnElement[bool]"], list[InstrumentedAttribute[Any]]]:
|
||||||
|
"""Build SQLAlchemy filter conditions for search.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
search: Search string or SearchConfig
|
||||||
|
search_fields: Fields specified per-call (takes priority)
|
||||||
|
default_fields: Default fields (from ClassVar)
|
||||||
|
search_column: Optional key to narrow search to a single field.
|
||||||
|
Must match one of the resolved search field keys.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (filter_conditions, joins_needed)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NoSearchableFieldsError: If no searchable field has been configured
|
||||||
|
"""
|
||||||
|
# Normalize input
|
||||||
|
if isinstance(search, str):
|
||||||
|
config = SearchConfig(query=search, fields=search_fields)
|
||||||
|
else:
|
||||||
|
config = (
|
||||||
|
replace(search, fields=search_fields)
|
||||||
|
if search_fields is not None
|
||||||
|
else search
|
||||||
|
)
|
||||||
|
|
||||||
|
if not config.query or not config.query.strip():
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
# Determine which fields to search
|
||||||
|
fields = config.fields or default_fields or get_searchable_fields(model)
|
||||||
|
|
||||||
|
if not fields:
|
||||||
|
raise NoSearchableFieldsError(model)
|
||||||
|
|
||||||
|
# Narrow to a single column when search_column is specified
|
||||||
|
if search_column is not None:
|
||||||
|
keys = search_field_keys(fields)
|
||||||
|
index = {k: f for k, f in zip(keys, fields)}
|
||||||
|
if search_column not in index:
|
||||||
|
raise InvalidSearchColumnError(search_column, sorted(index))
|
||||||
|
fields = [index[search_column]]
|
||||||
|
|
||||||
|
query = config.query.strip()
|
||||||
|
filters: list[ColumnElement[bool]] = []
|
||||||
|
joins: list[InstrumentedAttribute[Any]] = []
|
||||||
|
added_joins: set[str] = set()
|
||||||
|
|
||||||
|
for field in fields:
|
||||||
|
if isinstance(field, tuple):
|
||||||
|
# Relationship: (User.role, Role.name) or deeper
|
||||||
|
for rel in field[:-1]:
|
||||||
|
rel_key = str(rel)
|
||||||
|
if rel_key not in added_joins:
|
||||||
|
joins.append(rel)
|
||||||
|
added_joins.add(rel_key)
|
||||||
|
column = field[-1]
|
||||||
|
else:
|
||||||
|
column = field
|
||||||
|
|
||||||
|
# Build the filter (cast to String for non-text columns)
|
||||||
|
column_as_string = column.cast(String)
|
||||||
|
if config.case_sensitive:
|
||||||
|
filters.append(column_as_string.like(f"%{query}%"))
|
||||||
|
else:
|
||||||
|
filters.append(column_as_string.ilike(f"%{query}%"))
|
||||||
|
|
||||||
|
if not filters: # pragma: no cover
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
# Combine based on match_mode
|
||||||
|
if config.match_mode == "any":
|
||||||
|
return [or_(*filters)], joins
|
||||||
|
else:
|
||||||
|
return filters, joins
|
||||||
|
|
||||||
|
|
||||||
|
def search_field_keys(fields: Sequence[SearchFieldType]) -> list[str]:
|
||||||
|
"""Return a human-readable key for each search field."""
|
||||||
|
return facet_keys(fields)
|
||||||
|
|
||||||
|
|
||||||
|
def facet_keys(facet_fields: Sequence[FacetFieldType]) -> list[str]:
|
||||||
|
"""Return a key for each facet field.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
facet_fields: Sequence of facet fields — either direct columns or
|
||||||
|
relationship tuples ``(rel, ..., column)``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of string keys, one per facet field, in the same order.
|
||||||
|
"""
|
||||||
|
keys: list[str] = []
|
||||||
|
for field in facet_fields:
|
||||||
|
if isinstance(field, tuple):
|
||||||
|
keys.append("__".join(el.key for el in field))
|
||||||
|
else:
|
||||||
|
keys.append(field.key)
|
||||||
|
return keys
|
||||||
|
|
||||||
|
|
||||||
|
async def build_facets(
|
||||||
|
session: "AsyncSession",
|
||||||
|
model: type[DeclarativeBase],
|
||||||
|
facet_fields: Sequence[FacetFieldType],
|
||||||
|
*,
|
||||||
|
base_filters: "list[ColumnElement[bool]] | None" = None,
|
||||||
|
base_joins: list[InstrumentedAttribute[Any]] | None = None,
|
||||||
|
) -> dict[str, list[Any]]:
|
||||||
|
"""Return distinct values for each facet field, respecting current filters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: DB async session
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
facet_fields: Columns or relationship tuples to facet on
|
||||||
|
base_filters: Filter conditions already applied to the main query (search + caller filters)
|
||||||
|
base_joins: Relationship joins already applied to the main query
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping column key to sorted list of distinct non-None values
|
||||||
|
"""
|
||||||
|
existing_join_keys: set[str] = {str(j) for j in (base_joins or [])}
|
||||||
|
|
||||||
|
keys = facet_keys(facet_fields)
|
||||||
|
|
||||||
|
async def _query_facet(field: FacetFieldType, key: str) -> tuple[str, list[Any]]:
|
||||||
|
if isinstance(field, tuple):
|
||||||
|
# Relationship chain: (User.role, Role.name) — last element is the column
|
||||||
|
rels = field[:-1]
|
||||||
|
column = field[-1]
|
||||||
|
else:
|
||||||
|
rels = ()
|
||||||
|
column = field
|
||||||
|
|
||||||
|
col_type = column.property.columns[0].type
|
||||||
|
is_array = isinstance(col_type, ARRAY)
|
||||||
|
|
||||||
|
if is_array:
|
||||||
|
unnested = func.unnest(column).label(column.key)
|
||||||
|
q = select(unnested).select_from(model).distinct()
|
||||||
|
else:
|
||||||
|
q = select(column).select_from(model).distinct()
|
||||||
|
|
||||||
|
# Apply base joins (deduplicated) — needed here independently
|
||||||
|
seen_joins: set[str] = set()
|
||||||
|
for rel in base_joins or []:
|
||||||
|
rel_key = str(rel)
|
||||||
|
if rel_key not in seen_joins:
|
||||||
|
seen_joins.add(rel_key)
|
||||||
|
q = q.outerjoin(rel)
|
||||||
|
|
||||||
|
# Add any extra joins required by this facet field that aren't already applied
|
||||||
|
for rel in rels:
|
||||||
|
rel_key = str(rel)
|
||||||
|
if rel_key not in existing_join_keys and rel_key not in seen_joins:
|
||||||
|
seen_joins.add(rel_key)
|
||||||
|
q = q.outerjoin(rel)
|
||||||
|
|
||||||
|
if base_filters:
|
||||||
|
q = q.where(and_(*base_filters))
|
||||||
|
|
||||||
|
if is_array:
|
||||||
|
q = q.order_by(unnested)
|
||||||
|
else:
|
||||||
|
q = q.order_by(column)
|
||||||
|
result = await session.execute(q)
|
||||||
|
values = [row[0] for row in result.all() if row[0] is not None]
|
||||||
|
return key, values
|
||||||
|
|
||||||
|
pairs = await asyncio.gather(
|
||||||
|
*[_query_facet(f, k) for f, k in zip(facet_fields, keys)]
|
||||||
|
)
|
||||||
|
return dict(pairs)
|
||||||
|
|
||||||
|
|
||||||
|
_EQUALITY_TYPES = (String, Integer, Numeric, Date, DateTime, Time, Enum, Uuid)
|
||||||
|
"""Column types that support equality / IN filtering in build_filter_by."""
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_bool(value: Any) -> bool:
|
||||||
|
"""Coerce a string value to a Python bool for Boolean column filtering."""
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
if isinstance(value, str):
|
||||||
|
if value.lower() == "true":
|
||||||
|
return True
|
||||||
|
if value.lower() == "false":
|
||||||
|
return False
|
||||||
|
raise ValueError(f"Cannot coerce {value!r} to bool")
|
||||||
|
|
||||||
|
|
||||||
|
def build_filter_by(
|
||||||
|
filter_by: dict[str, Any],
|
||||||
|
facet_fields: Sequence[FacetFieldType],
|
||||||
|
) -> tuple["list[ColumnElement[bool]]", list[InstrumentedAttribute[Any]]]:
|
||||||
|
"""Translate a {column_key: value} dict into SQLAlchemy filter conditions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filter_by: Mapping of column key to scalar value or list of values
|
||||||
|
facet_fields: Declared facet fields to validate keys against
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (filter_conditions, joins_needed)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidFacetFilterError: If a key in filter_by is not a declared facet field
|
||||||
|
"""
|
||||||
|
index: dict[
|
||||||
|
str, tuple[InstrumentedAttribute[Any], list[InstrumentedAttribute[Any]]]
|
||||||
|
] = {}
|
||||||
|
for key, field in zip(facet_keys(facet_fields), facet_fields):
|
||||||
|
if isinstance(field, tuple):
|
||||||
|
rels = list(field[:-1])
|
||||||
|
column = field[-1]
|
||||||
|
else:
|
||||||
|
rels = []
|
||||||
|
column = field
|
||||||
|
index[key] = (column, rels)
|
||||||
|
|
||||||
|
valid_keys = set(index)
|
||||||
|
filters: list[ColumnElement[bool]] = []
|
||||||
|
joins: list[InstrumentedAttribute[Any]] = []
|
||||||
|
added_join_keys: set[str] = set()
|
||||||
|
|
||||||
|
for key, value in filter_by.items():
|
||||||
|
if key not in index:
|
||||||
|
raise InvalidFacetFilterError(key, valid_keys)
|
||||||
|
|
||||||
|
column, rels = index[key]
|
||||||
|
|
||||||
|
for rel in rels:
|
||||||
|
rel_key = str(rel)
|
||||||
|
if rel_key not in added_join_keys:
|
||||||
|
joins.append(rel)
|
||||||
|
added_join_keys.add(rel_key)
|
||||||
|
|
||||||
|
col_type = column.property.columns[0].type
|
||||||
|
if isinstance(col_type, Boolean):
|
||||||
|
coerce = _coerce_bool
|
||||||
|
if isinstance(value, list):
|
||||||
|
filters.append(column.in_([coerce(v) for v in value]))
|
||||||
|
else:
|
||||||
|
filters.append(column == coerce(value))
|
||||||
|
elif isinstance(col_type, ARRAY):
|
||||||
|
if isinstance(value, list):
|
||||||
|
filters.append(column.overlap(value))
|
||||||
|
else:
|
||||||
|
filters.append(column.any(value))
|
||||||
|
elif isinstance(col_type, _EQUALITY_TYPES):
|
||||||
|
if isinstance(value, list):
|
||||||
|
filters.append(column.in_(value))
|
||||||
|
else:
|
||||||
|
filters.append(column == value)
|
||||||
|
else:
|
||||||
|
raise UnsupportedFacetTypeError(key, type(col_type).__name__)
|
||||||
|
|
||||||
|
return filters, joins
|
||||||
@@ -1,25 +1,35 @@
|
|||||||
"""Database utilities: sessions, transactions, and locks."""
|
"""Database utilities: sessions, transactions, and locks."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from collections.abc import AsyncGenerator, Callable
|
from collections.abc import AsyncGenerator, Callable
|
||||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from .exceptions import NotFoundError
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"LockMode",
|
"LockMode",
|
||||||
|
"cleanup_tables",
|
||||||
|
"create_database",
|
||||||
"create_db_context",
|
"create_db_context",
|
||||||
"create_db_dependency",
|
"create_db_dependency",
|
||||||
"lock_tables",
|
|
||||||
"get_transaction",
|
"get_transaction",
|
||||||
|
"lock_tables",
|
||||||
|
"wait_for_row_change",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
_SessionT = TypeVar("_SessionT", bound=AsyncSession)
|
||||||
|
|
||||||
|
|
||||||
def create_db_dependency(
|
def create_db_dependency(
|
||||||
session_maker: async_sessionmaker[AsyncSession],
|
session_maker: async_sessionmaker[_SessionT],
|
||||||
) -> Callable[[], AsyncGenerator[AsyncSession, None]]:
|
) -> Callable[[], AsyncGenerator[_SessionT, None]]:
|
||||||
"""Create a FastAPI dependency for database sessions.
|
"""Create a FastAPI dependency for database sessions.
|
||||||
|
|
||||||
Creates a dependency function that yields a session and auto-commits
|
Creates a dependency function that yields a session and auto-commits
|
||||||
@@ -32,6 +42,7 @@ def create_db_dependency(
|
|||||||
An async generator function usable with FastAPI's Depends()
|
An async generator function usable with FastAPI's Depends()
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||||
from fastapi_toolsets.db import create_db_dependency
|
from fastapi_toolsets.db import create_db_dependency
|
||||||
@@ -43,10 +54,12 @@ def create_db_dependency(
|
|||||||
@app.get("/users")
|
@app.get("/users")
|
||||||
async def list_users(session: AsyncSession = Depends(get_db)):
|
async def list_users(session: AsyncSession = Depends(get_db)):
|
||||||
...
|
...
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
async def get_db() -> AsyncGenerator[_SessionT, None]:
|
||||||
async with session_maker() as session:
|
async with session_maker() as session:
|
||||||
|
await session.connection()
|
||||||
yield session
|
yield session
|
||||||
if session.in_transaction():
|
if session.in_transaction():
|
||||||
await session.commit()
|
await session.commit()
|
||||||
@@ -55,8 +68,8 @@ def create_db_dependency(
|
|||||||
|
|
||||||
|
|
||||||
def create_db_context(
|
def create_db_context(
|
||||||
session_maker: async_sessionmaker[AsyncSession],
|
session_maker: async_sessionmaker[_SessionT],
|
||||||
) -> Callable[[], AbstractAsyncContextManager[AsyncSession]]:
|
) -> Callable[[], AbstractAsyncContextManager[_SessionT]]:
|
||||||
"""Create a context manager for database sessions.
|
"""Create a context manager for database sessions.
|
||||||
|
|
||||||
Creates a context manager for use outside of FastAPI request handlers,
|
Creates a context manager for use outside of FastAPI request handlers,
|
||||||
@@ -69,6 +82,7 @@ def create_db_context(
|
|||||||
An async context manager function
|
An async context manager function
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||||
from fastapi_toolsets.db import create_db_context
|
from fastapi_toolsets.db import create_db_context
|
||||||
|
|
||||||
@@ -80,6 +94,7 @@ def create_db_context(
|
|||||||
async with get_db_context() as session:
|
async with get_db_context() as session:
|
||||||
user = await UserCrud.get(session, [User.id == 1])
|
user = await UserCrud.get(session, [User.id == 1])
|
||||||
...
|
...
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
get_db = create_db_dependency(session_maker)
|
get_db = create_db_dependency(session_maker)
|
||||||
return asynccontextmanager(get_db)
|
return asynccontextmanager(get_db)
|
||||||
@@ -101,9 +116,11 @@ async def get_transaction(
|
|||||||
The session within the transaction context
|
The session within the transaction context
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
async with get_transaction(session):
|
async with get_transaction(session):
|
||||||
session.add(model)
|
session.add(model)
|
||||||
# Auto-commits on exit, rolls back on exception
|
# Auto-commits on exit, rolls back on exception
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
if session.in_transaction():
|
if session.in_transaction():
|
||||||
async with session.begin_nested():
|
async with session.begin_nested():
|
||||||
@@ -155,6 +172,7 @@ async def lock_tables(
|
|||||||
SQLAlchemyError: If lock cannot be acquired within timeout
|
SQLAlchemyError: If lock cannot be acquired within timeout
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
from fastapi_toolsets.db import lock_tables, LockMode
|
from fastapi_toolsets.db import lock_tables, LockMode
|
||||||
|
|
||||||
async with lock_tables(session, [User, Account]):
|
async with lock_tables(session, [User, Account]):
|
||||||
@@ -166,6 +184,7 @@ async def lock_tables(
|
|||||||
async with lock_tables(session, [Order], mode=LockMode.EXCLUSIVE):
|
async with lock_tables(session, [Order], mode=LockMode.EXCLUSIVE):
|
||||||
# Exclusive lock - no other transactions can access
|
# Exclusive lock - no other transactions can access
|
||||||
await process_order(session, order_id)
|
await process_order(session, order_id)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
table_names = ",".join(table.__tablename__ for table in tables)
|
table_names = ",".join(table.__tablename__ for table in tables)
|
||||||
|
|
||||||
@@ -173,3 +192,150 @@ async def lock_tables(
|
|||||||
await session.execute(text(f"SET LOCAL lock_timeout='{timeout}'"))
|
await session.execute(text(f"SET LOCAL lock_timeout='{timeout}'"))
|
||||||
await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE"))
|
await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE"))
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
async def create_database(
|
||||||
|
db_name: str,
|
||||||
|
*,
|
||||||
|
server_url: str,
|
||||||
|
) -> None:
|
||||||
|
"""Create a database.
|
||||||
|
|
||||||
|
Connects to *server_url* using ``AUTOCOMMIT`` isolation and issues a
|
||||||
|
``CREATE DATABASE`` statement for *db_name*.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_name: Name of the database to create.
|
||||||
|
server_url: URL used for server-level DDL (must point to an existing
|
||||||
|
database on the same server).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import create_database
|
||||||
|
|
||||||
|
SERVER_URL = "postgresql+asyncpg://postgres:postgres@localhost/postgres"
|
||||||
|
await create_database("myapp_test", server_url=SERVER_URL)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
engine = create_async_engine(server_url, isolation_level="AUTOCOMMIT")
|
||||||
|
try:
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
await conn.execute(text(f"CREATE DATABASE {db_name}"))
|
||||||
|
finally:
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
async def cleanup_tables(
|
||||||
|
session: AsyncSession,
|
||||||
|
base: type[DeclarativeBase],
|
||||||
|
) -> None:
|
||||||
|
"""Truncate all tables for fast between-test cleanup.
|
||||||
|
|
||||||
|
Executes a single ``TRUNCATE … RESTART IDENTITY CASCADE`` statement
|
||||||
|
across every table in *base*'s metadata, which is significantly faster
|
||||||
|
than dropping and re-creating tables between tests.
|
||||||
|
|
||||||
|
This is a no-op when the metadata contains no tables.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: An active async database session.
|
||||||
|
base: SQLAlchemy DeclarativeBase class containing model metadata.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session(worker_db_url):
|
||||||
|
async with create_db_session(worker_db_url, Base) as session:
|
||||||
|
yield session
|
||||||
|
await cleanup_tables(session, Base)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
tables = base.metadata.sorted_tables
|
||||||
|
if not tables:
|
||||||
|
return
|
||||||
|
|
||||||
|
table_names = ", ".join(f'"{t.name}"' for t in tables)
|
||||||
|
await session.execute(text(f"TRUNCATE {table_names} RESTART IDENTITY CASCADE"))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
_M = TypeVar("_M", bound=DeclarativeBase)
|
||||||
|
|
||||||
|
|
||||||
|
async def wait_for_row_change(
|
||||||
|
session: AsyncSession,
|
||||||
|
model: type[_M],
|
||||||
|
pk_value: Any,
|
||||||
|
*,
|
||||||
|
columns: list[str] | None = None,
|
||||||
|
interval: float = 0.5,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> _M:
|
||||||
|
"""Poll a database row until a change is detected.
|
||||||
|
|
||||||
|
Queries the row every ``interval`` seconds and returns the model instance
|
||||||
|
once a change is detected in any column (or only the specified ``columns``).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: AsyncSession instance
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
pk_value: Primary key value of the row to watch
|
||||||
|
columns: Optional list of column names to watch. If None, all columns
|
||||||
|
are watched.
|
||||||
|
interval: Polling interval in seconds (default: 0.5)
|
||||||
|
timeout: Maximum time to wait in seconds. None means wait forever.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The refreshed model instance with updated values
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If the row does not exist or is deleted during polling
|
||||||
|
TimeoutError: If timeout expires before a change is detected
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import wait_for_row_change
|
||||||
|
|
||||||
|
# Wait for any column to change
|
||||||
|
updated = await wait_for_row_change(session, User, user_id)
|
||||||
|
|
||||||
|
# Watch specific columns with a timeout
|
||||||
|
updated = await wait_for_row_change(
|
||||||
|
session, User, user_id,
|
||||||
|
columns=["status", "email"],
|
||||||
|
interval=1.0,
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
instance = await session.get(model, pk_value)
|
||||||
|
if instance is None:
|
||||||
|
raise NotFoundError(f"{model.__name__} with pk={pk_value!r} not found")
|
||||||
|
|
||||||
|
if columns is not None:
|
||||||
|
watch_cols = columns
|
||||||
|
else:
|
||||||
|
watch_cols = [attr.key for attr in model.__mapper__.column_attrs]
|
||||||
|
|
||||||
|
initial = {col: getattr(instance, col) for col in watch_cols}
|
||||||
|
|
||||||
|
elapsed = 0.0
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
elapsed += interval
|
||||||
|
|
||||||
|
if timeout is not None and elapsed >= timeout:
|
||||||
|
raise TimeoutError(
|
||||||
|
f"No change detected on {model.__name__} "
|
||||||
|
f"with pk={pk_value!r} within {timeout}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
session.expunge(instance)
|
||||||
|
instance = await session.get(model, pk_value)
|
||||||
|
|
||||||
|
if instance is None:
|
||||||
|
raise NotFoundError(f"{model.__name__} with pk={pk_value!r} was deleted")
|
||||||
|
|
||||||
|
current = {col: getattr(instance, col) for col in watch_cols}
|
||||||
|
if current != initial:
|
||||||
|
return instance
|
||||||
|
|||||||
155
src/fastapi_toolsets/dependencies.py
Normal file
155
src/fastapi_toolsets/dependencies.py
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
"""Dependency factories for FastAPI routes."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import typing
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
from fastapi.params import Depends as DependsClass
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from .crud import CrudFactory
|
||||||
|
from .types import ModelType, SessionDependency
|
||||||
|
|
||||||
|
__all__ = ["BodyDependency", "PathDependency"]
|
||||||
|
|
||||||
|
|
||||||
|
def _unwrap_session_dep(session_dep: SessionDependency) -> Callable[..., Any]:
|
||||||
|
"""Extract the plain callable from ``Annotated[AsyncSession, Depends(fn)]`` if needed."""
|
||||||
|
if typing.get_origin(session_dep) is typing.Annotated:
|
||||||
|
for arg in typing.get_args(session_dep)[1:]:
|
||||||
|
if isinstance(arg, DependsClass):
|
||||||
|
return arg.dependency
|
||||||
|
return session_dep
|
||||||
|
|
||||||
|
|
||||||
|
def PathDependency(
|
||||||
|
model: type[ModelType],
|
||||||
|
field: Any,
|
||||||
|
*,
|
||||||
|
session_dep: SessionDependency,
|
||||||
|
param_name: str | None = None,
|
||||||
|
) -> ModelType:
|
||||||
|
"""Create a dependency that fetches a DB object from a path parameter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
field: Model field to filter by (e.g., User.id)
|
||||||
|
session_dep: Session dependency function (e.g., get_db)
|
||||||
|
param_name: Path parameter name (defaults to model_field, e.g., user_id)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Depends() instance that resolves to the model instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If no matching record is found
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
UserDep = PathDependency(User, User.id, session_dep=get_db)
|
||||||
|
|
||||||
|
@router.get("/user/{id}")
|
||||||
|
async def get(
|
||||||
|
user: User = UserDep,
|
||||||
|
): ...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
session_callable = _unwrap_session_dep(session_dep)
|
||||||
|
crud = CrudFactory(model)
|
||||||
|
name = (
|
||||||
|
param_name
|
||||||
|
if param_name is not None
|
||||||
|
else "{}_{}".format(model.__name__.lower(), field.key)
|
||||||
|
)
|
||||||
|
python_type = field.type.python_type
|
||||||
|
|
||||||
|
async def dependency(
|
||||||
|
session: AsyncSession = Depends(session_callable), **kwargs: Any
|
||||||
|
) -> ModelType:
|
||||||
|
value = kwargs[name]
|
||||||
|
return await crud.get(session, filters=[field == value])
|
||||||
|
|
||||||
|
setattr(
|
||||||
|
dependency,
|
||||||
|
"__signature__",
|
||||||
|
inspect.Signature(
|
||||||
|
parameters=[
|
||||||
|
inspect.Parameter(
|
||||||
|
name, inspect.Parameter.KEYWORD_ONLY, annotation=python_type
|
||||||
|
),
|
||||||
|
inspect.Parameter(
|
||||||
|
"session",
|
||||||
|
inspect.Parameter.KEYWORD_ONLY,
|
||||||
|
annotation=AsyncSession,
|
||||||
|
default=Depends(session_callable),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return cast(ModelType, Depends(cast(Callable[..., ModelType], dependency)))
|
||||||
|
|
||||||
|
|
||||||
|
def BodyDependency(
|
||||||
|
model: type[ModelType],
|
||||||
|
field: Any,
|
||||||
|
*,
|
||||||
|
session_dep: SessionDependency,
|
||||||
|
body_field: str,
|
||||||
|
) -> ModelType:
|
||||||
|
"""Create a dependency that fetches a DB object from a body field.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
field: Model field to filter by (e.g., User.id)
|
||||||
|
session_dep: Session dependency function (e.g., get_db)
|
||||||
|
body_field: Name of the field in the request body
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Depends() instance that resolves to the model instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If no matching record is found
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
UserDep = BodyDependency(
|
||||||
|
User, User.ctfd_id, session_dep=get_db, body_field="user_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.post("/assign")
|
||||||
|
async def assign(
|
||||||
|
user: User = UserDep,
|
||||||
|
): ...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
session_callable = _unwrap_session_dep(session_dep)
|
||||||
|
crud = CrudFactory(model)
|
||||||
|
python_type = field.type.python_type
|
||||||
|
|
||||||
|
async def dependency(
|
||||||
|
session: AsyncSession = Depends(session_callable), **kwargs: Any
|
||||||
|
) -> ModelType:
|
||||||
|
value = kwargs[body_field]
|
||||||
|
return await crud.get(session, filters=[field == value])
|
||||||
|
|
||||||
|
setattr(
|
||||||
|
dependency,
|
||||||
|
"__signature__",
|
||||||
|
inspect.Signature(
|
||||||
|
parameters=[
|
||||||
|
inspect.Parameter(
|
||||||
|
body_field, inspect.Parameter.KEYWORD_ONLY, annotation=python_type
|
||||||
|
),
|
||||||
|
inspect.Parameter(
|
||||||
|
"session",
|
||||||
|
inspect.Parameter.KEYWORD_ONLY,
|
||||||
|
annotation=AsyncSession,
|
||||||
|
default=Depends(session_callable),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return cast(ModelType, Depends(cast(Callable[..., ModelType], dependency)))
|
||||||
@@ -1,19 +1,33 @@
|
|||||||
|
"""Standardized API exceptions and error response handlers."""
|
||||||
|
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
|
ApiError,
|
||||||
ApiException,
|
ApiException,
|
||||||
ConflictError,
|
ConflictError,
|
||||||
ForbiddenError,
|
ForbiddenError,
|
||||||
|
InvalidFacetFilterError,
|
||||||
|
InvalidOrderFieldError,
|
||||||
|
InvalidSearchColumnError,
|
||||||
|
NoSearchableFieldsError,
|
||||||
NotFoundError,
|
NotFoundError,
|
||||||
UnauthorizedError,
|
UnauthorizedError,
|
||||||
|
UnsupportedFacetTypeError,
|
||||||
generate_error_responses,
|
generate_error_responses,
|
||||||
)
|
)
|
||||||
from .handler import init_exceptions_handlers
|
from .handler import init_exceptions_handlers
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"init_exceptions_handlers",
|
"ApiError",
|
||||||
"generate_error_responses",
|
|
||||||
"ApiException",
|
"ApiException",
|
||||||
"ConflictError",
|
"ConflictError",
|
||||||
"ForbiddenError",
|
"ForbiddenError",
|
||||||
|
"generate_error_responses",
|
||||||
|
"init_exceptions_handlers",
|
||||||
|
"InvalidFacetFilterError",
|
||||||
|
"InvalidOrderFieldError",
|
||||||
|
"InvalidSearchColumnError",
|
||||||
|
"NoSearchableFieldsError",
|
||||||
"NotFoundError",
|
"NotFoundError",
|
||||||
"UnauthorizedError",
|
"UnauthorizedError",
|
||||||
|
"UnsupportedFacetTypeError",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -6,30 +6,46 @@ from ..schemas import ApiError, ErrorResponse, ResponseStatus
|
|||||||
|
|
||||||
|
|
||||||
class ApiException(Exception):
|
class ApiException(Exception):
|
||||||
"""Base exception for API errors with structured response.
|
"""Base exception for API errors with structured response."""
|
||||||
|
|
||||||
Subclass this to create custom API exceptions with consistent error format.
|
|
||||||
The exception handler will use api_error to generate the response.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
class CustomError(ApiException):
|
|
||||||
api_error = ApiError(
|
|
||||||
code=400,
|
|
||||||
msg="Bad Request",
|
|
||||||
desc="The request was invalid.",
|
|
||||||
err_code="CUSTOM-400",
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
api_error: ClassVar[ApiError]
|
api_error: ClassVar[ApiError]
|
||||||
|
|
||||||
def __init__(self, detail: str | None = None):
|
def __init_subclass__(cls, abstract: bool = False, **kwargs: Any) -> None:
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
if not abstract and not hasattr(cls, "api_error"):
|
||||||
|
raise TypeError(
|
||||||
|
f"{cls.__name__} must define an 'api_error' class attribute. "
|
||||||
|
"Pass abstract=True when creating intermediate base classes."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
detail: str | None = None,
|
||||||
|
*,
|
||||||
|
desc: str | None = None,
|
||||||
|
data: Any = None,
|
||||||
|
) -> None:
|
||||||
"""Initialize the exception.
|
"""Initialize the exception.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
detail: Optional override for the error message
|
detail: Optional human-readable message
|
||||||
|
desc: Optional per-instance override for the ``description`` field
|
||||||
|
in the HTTP response body.
|
||||||
|
data: Optional per-instance override for the ``data`` field in the
|
||||||
|
HTTP response body.
|
||||||
"""
|
"""
|
||||||
super().__init__(detail or self.api_error.msg)
|
updates: dict[str, Any] = {}
|
||||||
|
if detail is not None:
|
||||||
|
updates["msg"] = detail
|
||||||
|
if desc is not None:
|
||||||
|
updates["desc"] = desc
|
||||||
|
if data is not None:
|
||||||
|
updates["data"] = data
|
||||||
|
if updates:
|
||||||
|
object.__setattr__(
|
||||||
|
self, "api_error", self.__class__.api_error.model_copy(update=updates)
|
||||||
|
)
|
||||||
|
super().__init__(self.api_error.msg)
|
||||||
|
|
||||||
|
|
||||||
class UnauthorizedError(ApiException):
|
class UnauthorizedError(ApiException):
|
||||||
@@ -76,90 +92,175 @@ class ConflictError(ApiException):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InsufficientRolesError(ForbiddenError):
|
class NoSearchableFieldsError(ApiException):
|
||||||
"""User does not have the required roles."""
|
"""Raised when search is requested but no searchable fields are available."""
|
||||||
|
|
||||||
api_error = ApiError(
|
api_error = ApiError(
|
||||||
code=403,
|
code=400,
|
||||||
msg="Insufficient Roles",
|
msg="No Searchable Fields",
|
||||||
desc="You do not have the required roles to access this resource.",
|
desc="No searchable fields configured for this resource.",
|
||||||
err_code="RBAC-403",
|
err_code="SEARCH-400",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, required_roles: list[str], user_roles: set[str] | None = None):
|
def __init__(self, model: type) -> None:
|
||||||
self.required_roles = required_roles
|
"""Initialize the exception.
|
||||||
self.user_roles = user_roles
|
|
||||||
|
|
||||||
desc = f"Required roles: {', '.join(required_roles)}"
|
Args:
|
||||||
if user_roles is not None:
|
model: The model class that has no searchable fields configured.
|
||||||
desc += f". User has: {', '.join(user_roles) if user_roles else 'no roles'}"
|
"""
|
||||||
|
self.model = model
|
||||||
super().__init__(desc)
|
super().__init__(
|
||||||
|
desc=(
|
||||||
|
f"No searchable fields found for model '{model.__name__}'. "
|
||||||
|
"Provide 'search_fields' parameter or set 'searchable_fields' on the CRUD class."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class UserNotFoundError(NotFoundError):
|
class InvalidFacetFilterError(ApiException):
|
||||||
"""User was not found."""
|
"""Raised when filter_by contains a key not declared in facet_fields."""
|
||||||
|
|
||||||
api_error = ApiError(
|
api_error = ApiError(
|
||||||
code=404,
|
code=400,
|
||||||
msg="User Not Found",
|
msg="Invalid Facet Filter",
|
||||||
desc="The requested user was not found.",
|
desc="One or more filter_by keys are not declared as facet fields.",
|
||||||
err_code="USER-404",
|
err_code="FACET-400",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __init__(self, key: str, valid_keys: set[str]) -> None:
|
||||||
|
"""Initialize the exception.
|
||||||
|
|
||||||
class RoleNotFoundError(NotFoundError):
|
Args:
|
||||||
"""Role was not found."""
|
key: The unknown filter key provided by the caller.
|
||||||
|
valid_keys: Set of valid keys derived from the declared facet_fields.
|
||||||
|
"""
|
||||||
|
self.key = key
|
||||||
|
self.valid_keys = valid_keys
|
||||||
|
super().__init__(
|
||||||
|
desc=(
|
||||||
|
f"'{key}' is not a declared facet field. "
|
||||||
|
f"Valid keys: {sorted(valid_keys) or 'none — set facet_fields on the CRUD class'}."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedFacetTypeError(ApiException):
|
||||||
|
"""Raised when a facet field has a column type not supported by filter_by."""
|
||||||
|
|
||||||
api_error = ApiError(
|
api_error = ApiError(
|
||||||
code=404,
|
code=400,
|
||||||
msg="Role Not Found",
|
msg="Unsupported Facet Type",
|
||||||
desc="The requested role was not found.",
|
desc="The column type is not supported for facet filtering.",
|
||||||
err_code="ROLE-404",
|
err_code="FACET-TYPE-400",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __init__(self, key: str, col_type: str) -> None:
|
||||||
|
"""Initialize the exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The facet field key.
|
||||||
|
col_type: The unsupported column type name.
|
||||||
|
"""
|
||||||
|
self.key = key
|
||||||
|
self.col_type = col_type
|
||||||
|
super().__init__(
|
||||||
|
desc=(
|
||||||
|
f"Facet field '{key}' has unsupported column type '{col_type}'. "
|
||||||
|
f"Supported types: String, Integer, Numeric, Boolean, "
|
||||||
|
f"Date, DateTime, Time, Enum, Uuid, ARRAY."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidSearchColumnError(ApiException):
|
||||||
|
"""Raised when search_column is not one of the configured searchable fields."""
|
||||||
|
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400,
|
||||||
|
msg="Invalid Search Column",
|
||||||
|
desc="The requested search column is not a configured searchable field.",
|
||||||
|
err_code="SEARCH-COL-400",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, column: str, valid_columns: list[str]) -> None:
|
||||||
|
"""Initialize the exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
column: The unknown search column provided by the caller.
|
||||||
|
valid_columns: List of valid search column keys.
|
||||||
|
"""
|
||||||
|
self.column = column
|
||||||
|
self.valid_columns = valid_columns
|
||||||
|
super().__init__(
|
||||||
|
desc=(
|
||||||
|
f"'{column}' is not a searchable column. "
|
||||||
|
f"Valid columns: {valid_columns}."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidOrderFieldError(ApiException):
|
||||||
|
"""Raised when order_by contains a field not in the allowed order fields."""
|
||||||
|
|
||||||
|
api_error = ApiError(
|
||||||
|
code=422,
|
||||||
|
msg="Invalid Order Field",
|
||||||
|
desc="The requested order field is not allowed for this resource.",
|
||||||
|
err_code="SORT-422",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, field: str, valid_fields: list[str]) -> None:
|
||||||
|
"""Initialize the exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field: The unknown order field provided by the caller.
|
||||||
|
valid_fields: List of valid field names.
|
||||||
|
"""
|
||||||
|
self.field = field
|
||||||
|
self.valid_fields = valid_fields
|
||||||
|
super().__init__(
|
||||||
|
desc=f"'{field}' is not an allowed order field. Valid fields: {valid_fields}."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def generate_error_responses(
|
def generate_error_responses(
|
||||||
*errors: type[ApiException],
|
*errors: type[ApiException],
|
||||||
) -> dict[int | str, dict[str, Any]]:
|
) -> dict[int | str, dict[str, Any]]:
|
||||||
"""Generate OpenAPI response documentation for exceptions.
|
"""Generate OpenAPI response documentation for exceptions.
|
||||||
|
|
||||||
Use this to document possible error responses for an endpoint.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
*errors: Exception classes that inherit from ApiException
|
*errors: Exception classes that inherit from ApiException.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict suitable for FastAPI's responses parameter
|
Dict suitable for FastAPI's ``responses`` parameter.
|
||||||
|
|
||||||
Example:
|
|
||||||
from fastapi_toolsets.exceptions import generate_error_responses, UnauthorizedError, ForbiddenError
|
|
||||||
|
|
||||||
@app.get(
|
|
||||||
"/admin",
|
|
||||||
responses=generate_error_responses(UnauthorizedError, ForbiddenError)
|
|
||||||
)
|
|
||||||
async def admin_endpoint():
|
|
||||||
...
|
|
||||||
"""
|
"""
|
||||||
responses: dict[int | str, dict[str, Any]] = {}
|
responses: dict[int | str, dict[str, Any]] = {}
|
||||||
|
|
||||||
for error in errors:
|
for error in errors:
|
||||||
api_error = error.api_error
|
api_error = error.api_error
|
||||||
|
code = api_error.code
|
||||||
|
|
||||||
responses[api_error.code] = {
|
if code not in responses:
|
||||||
"model": ErrorResponse,
|
responses[code] = {
|
||||||
"description": api_error.msg,
|
"model": ErrorResponse,
|
||||||
"content": {
|
"description": api_error.msg,
|
||||||
"application/json": {
|
"content": {
|
||||||
"example": {
|
"application/json": {
|
||||||
"data": None,
|
"examples": {},
|
||||||
"status": ResponseStatus.FAIL.value,
|
|
||||||
"message": api_error.msg,
|
|
||||||
"description": api_error.desc,
|
|
||||||
"error_code": api_error.err_code,
|
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
responses[code]["content"]["application/json"]["examples"][
|
||||||
|
api_error.err_code
|
||||||
|
] = {
|
||||||
|
"summary": api_error.msg,
|
||||||
|
"value": {
|
||||||
|
"data": api_error.data,
|
||||||
|
"status": ResponseStatus.FAIL.value,
|
||||||
|
"message": api_error.msg,
|
||||||
|
"description": api_error.desc,
|
||||||
|
"error_code": api_error.err_code,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,50 +1,69 @@
|
|||||||
"""Exception handlers for FastAPI applications."""
|
"""Exception handlers for FastAPI applications."""
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import FastAPI, Request, Response, status
|
from fastapi import FastAPI, Request, Response, status
|
||||||
from fastapi.exceptions import RequestValidationError, ResponseValidationError
|
from fastapi.exceptions import (
|
||||||
from fastapi.openapi.utils import get_openapi
|
HTTPException,
|
||||||
|
RequestValidationError,
|
||||||
|
ResponseValidationError,
|
||||||
|
)
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from ..schemas import ResponseStatus
|
from ..schemas import ErrorResponse, ResponseStatus
|
||||||
from .exceptions import ApiException
|
from .exceptions import ApiException
|
||||||
|
|
||||||
|
_VALIDATION_LOCATION_PARAMS: frozenset[str] = frozenset(
|
||||||
|
{"body", "query", "path", "header", "cookie"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def init_exceptions_handlers(app: FastAPI) -> FastAPI:
|
def init_exceptions_handlers(app: FastAPI) -> FastAPI:
|
||||||
|
"""Register exception handlers and custom OpenAPI schema on a FastAPI app.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI application instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The same FastAPI instance (for chaining).
|
||||||
|
"""
|
||||||
_register_exception_handlers(app)
|
_register_exception_handlers(app)
|
||||||
app.openapi = lambda: _custom_openapi(app) # type: ignore[method-assign]
|
_original_openapi = app.openapi
|
||||||
|
app.openapi = lambda: _patched_openapi(app, _original_openapi) # type: ignore[method-assign] # ty:ignore[invalid-assignment]
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
def _register_exception_handlers(app: FastAPI) -> None:
|
def _register_exception_handlers(app: FastAPI) -> None:
|
||||||
"""Register all exception handlers on a FastAPI application.
|
"""Register all exception handlers on a FastAPI application."""
|
||||||
|
|
||||||
Args:
|
|
||||||
app: FastAPI application instance
|
|
||||||
|
|
||||||
Example:
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
init_exceptions_handlers(app)
|
|
||||||
"""
|
|
||||||
|
|
||||||
@app.exception_handler(ApiException)
|
@app.exception_handler(ApiException)
|
||||||
async def api_exception_handler(request: Request, exc: ApiException) -> Response:
|
async def api_exception_handler(request: Request, exc: ApiException) -> Response:
|
||||||
"""Handle custom API exceptions with structured response."""
|
"""Handle custom API exceptions with structured response."""
|
||||||
api_error = exc.api_error
|
api_error = exc.api_error
|
||||||
|
error_response = ErrorResponse(
|
||||||
|
data=api_error.data,
|
||||||
|
message=api_error.msg,
|
||||||
|
description=api_error.desc,
|
||||||
|
error_code=api_error.err_code,
|
||||||
|
)
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=api_error.code,
|
status_code=api_error.code,
|
||||||
content={
|
content=error_response.model_dump(),
|
||||||
"data": None,
|
)
|
||||||
"status": ResponseStatus.FAIL.value,
|
|
||||||
"message": api_error.msg,
|
@app.exception_handler(HTTPException)
|
||||||
"description": api_error.desc,
|
async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
|
||||||
"error_code": api_error.err_code,
|
"""Handle Starlette/FastAPI HTTPException with a consistent error format."""
|
||||||
},
|
detail = exc.detail if isinstance(exc.detail, str) else "HTTP Error"
|
||||||
|
error_response = ErrorResponse(
|
||||||
|
message=detail,
|
||||||
|
error_code=f"HTTP-{exc.status_code}",
|
||||||
|
)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content=error_response.model_dump(),
|
||||||
|
headers=getattr(exc, "headers", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.exception_handler(RequestValidationError)
|
@app.exception_handler(RequestValidationError)
|
||||||
@@ -64,15 +83,14 @@ def _register_exception_handlers(app: FastAPI) -> None:
|
|||||||
@app.exception_handler(Exception)
|
@app.exception_handler(Exception)
|
||||||
async def generic_exception_handler(request: Request, exc: Exception) -> Response:
|
async def generic_exception_handler(request: Request, exc: Exception) -> Response:
|
||||||
"""Handle all unhandled exceptions with a generic 500 response."""
|
"""Handle all unhandled exceptions with a generic 500 response."""
|
||||||
|
error_response = ErrorResponse(
|
||||||
|
message="Internal Server Error",
|
||||||
|
description="An unexpected error occurred. Please try again later.",
|
||||||
|
error_code="SERVER-500",
|
||||||
|
)
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
content={
|
content=error_response.model_dump(),
|
||||||
"data": None,
|
|
||||||
"status": ResponseStatus.FAIL.value,
|
|
||||||
"message": "Internal Server Error",
|
|
||||||
"description": "An unexpected error occurred. Please try again later.",
|
|
||||||
"error_code": "SERVER-500",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -84,11 +102,10 @@ def _format_validation_error(
|
|||||||
formatted_errors = []
|
formatted_errors = []
|
||||||
|
|
||||||
for error in errors:
|
for error in errors:
|
||||||
field_path = ".".join(
|
locs = error["loc"]
|
||||||
str(loc)
|
if locs and locs[0] in _VALIDATION_LOCATION_PARAMS:
|
||||||
for loc in error["loc"]
|
locs = locs[1:]
|
||||||
if loc not in ("body", "query", "path", "header", "cookie")
|
field_path = ".".join(str(loc) for loc in locs)
|
||||||
)
|
|
||||||
formatted_errors.append(
|
formatted_errors.append(
|
||||||
{
|
{
|
||||||
"field": field_path or "root",
|
"field": field_path or "root",
|
||||||
@@ -97,46 +114,35 @@ def _format_validation_error(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
error_response = ErrorResponse(
|
||||||
|
data={"errors": formatted_errors},
|
||||||
|
message="Validation Error",
|
||||||
|
description=f"{len(formatted_errors)} validation error(s) detected",
|
||||||
|
error_code="VAL-422",
|
||||||
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
|
||||||
content={
|
content=error_response.model_dump(),
|
||||||
"data": {"errors": formatted_errors},
|
|
||||||
"status": ResponseStatus.FAIL.value,
|
|
||||||
"message": "Validation Error",
|
|
||||||
"description": f"{len(formatted_errors)} validation error(s) detected",
|
|
||||||
"error_code": "VAL-422",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _custom_openapi(app: FastAPI) -> dict[str, Any]:
|
def _patched_openapi(
|
||||||
"""Generate custom OpenAPI schema with standardized error format.
|
app: FastAPI, original_openapi: Callable[[], dict[str, Any]]
|
||||||
|
) -> dict[str, Any]:
|
||||||
Replaces default 422 validation error responses with the custom format.
|
"""Generate the OpenAPI schema and replace default 422 responses.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
app: FastAPI application instance
|
app: FastAPI application instance.
|
||||||
|
original_openapi: The previous ``app.openapi`` callable to delegate to.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
OpenAPI schema dict
|
Patched OpenAPI schema dict.
|
||||||
|
|
||||||
Example:
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
init_exceptions_handlers(app) # Automatically sets custom OpenAPI
|
|
||||||
"""
|
"""
|
||||||
if app.openapi_schema:
|
if app.openapi_schema:
|
||||||
return app.openapi_schema
|
return app.openapi_schema
|
||||||
|
|
||||||
openapi_schema = get_openapi(
|
openapi_schema = original_openapi()
|
||||||
title=app.title,
|
|
||||||
version=app.version,
|
|
||||||
openapi_version=app.openapi_version,
|
|
||||||
description=app.description,
|
|
||||||
routes=app.routes,
|
|
||||||
)
|
|
||||||
|
|
||||||
for path_data in openapi_schema.get("paths", {}).values():
|
for path_data in openapi_schema.get("paths", {}).values():
|
||||||
for operation in path_data.values():
|
for operation in path_data.values():
|
||||||
@@ -146,20 +152,25 @@ def _custom_openapi(app: FastAPI) -> dict[str, Any]:
|
|||||||
"description": "Validation Error",
|
"description": "Validation Error",
|
||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"example": {
|
"examples": {
|
||||||
"data": {
|
"VAL-422": {
|
||||||
"errors": [
|
"summary": "Validation Error",
|
||||||
{
|
"value": {
|
||||||
"field": "field_name",
|
"data": {
|
||||||
"message": "value is not valid",
|
"errors": [
|
||||||
"type": "value_error",
|
{
|
||||||
}
|
"field": "field_name",
|
||||||
]
|
"message": "value is not valid",
|
||||||
},
|
"type": "value_error",
|
||||||
"status": ResponseStatus.FAIL.value,
|
}
|
||||||
"message": "Validation Error",
|
]
|
||||||
"description": "1 validation error(s) detected",
|
},
|
||||||
"error_code": "VAL-422",
|
"status": ResponseStatus.FAIL.value,
|
||||||
|
"message": "Validation Error",
|
||||||
|
"description": "1 validation error(s) detected",
|
||||||
|
"error_code": "VAL-422",
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -1,16 +1,14 @@
|
|||||||
from .fixtures import (
|
"""Fixture system for seeding databases with dependency resolution."""
|
||||||
Context,
|
|
||||||
FixtureRegistry,
|
from .enum import LoadStrategy
|
||||||
LoadStrategy,
|
from .registry import Context, FixtureRegistry
|
||||||
load_fixtures,
|
from .utils import get_obj_by_attr, load_fixtures, load_fixtures_by_context
|
||||||
load_fixtures_by_context,
|
|
||||||
)
|
|
||||||
from .pytest_plugin import register_fixtures
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Context",
|
"Context",
|
||||||
"FixtureRegistry",
|
"FixtureRegistry",
|
||||||
"LoadStrategy",
|
"LoadStrategy",
|
||||||
|
"get_obj_by_attr",
|
||||||
"load_fixtures",
|
"load_fixtures",
|
||||||
"load_fixtures_by_context",
|
"load_fixtures_by_context",
|
||||||
"register_fixtures",
|
"register_fixtures",
|
||||||
|
|||||||
32
src/fastapi_toolsets/fixtures/enum.py
Normal file
32
src/fastapi_toolsets/fixtures/enum.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""Enums for fixture loading strategies and contexts."""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class LoadStrategy(str, Enum):
|
||||||
|
"""Strategy for loading fixtures into the database."""
|
||||||
|
|
||||||
|
INSERT = "insert"
|
||||||
|
"""Insert new records. Fails if record already exists."""
|
||||||
|
|
||||||
|
MERGE = "merge"
|
||||||
|
"""Insert or update based on primary key (SQLAlchemy merge)."""
|
||||||
|
|
||||||
|
SKIP_EXISTING = "skip_existing"
|
||||||
|
"""Insert only if record doesn't exist (based on primary key)."""
|
||||||
|
|
||||||
|
|
||||||
|
class Context(str, Enum):
|
||||||
|
"""Predefined fixture contexts."""
|
||||||
|
|
||||||
|
BASE = "base"
|
||||||
|
"""Base fixtures loaded in all environments."""
|
||||||
|
|
||||||
|
PRODUCTION = "production"
|
||||||
|
"""Production-only fixtures."""
|
||||||
|
|
||||||
|
DEVELOPMENT = "development"
|
||||||
|
"""Development fixtures."""
|
||||||
|
|
||||||
|
TESTING = "testing"
|
||||||
|
"""Test fixtures."""
|
||||||
@@ -1,321 +0,0 @@
|
|||||||
"""Fixture system with dependency management and context support."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from collections.abc import Callable, Sequence
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, cast
|
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
|
||||||
|
|
||||||
from ..db import get_transaction
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class LoadStrategy(str, Enum):
|
|
||||||
"""Strategy for loading fixtures into the database."""
|
|
||||||
|
|
||||||
INSERT = "insert"
|
|
||||||
"""Insert new records. Fails if record already exists."""
|
|
||||||
|
|
||||||
MERGE = "merge"
|
|
||||||
"""Insert or update based on primary key (SQLAlchemy merge)."""
|
|
||||||
|
|
||||||
SKIP_EXISTING = "skip_existing"
|
|
||||||
"""Insert only if record doesn't exist (based on primary key)."""
|
|
||||||
|
|
||||||
|
|
||||||
class Context(str, Enum):
|
|
||||||
"""Predefined fixture contexts."""
|
|
||||||
|
|
||||||
BASE = "base"
|
|
||||||
"""Base fixtures loaded in all environments."""
|
|
||||||
|
|
||||||
PRODUCTION = "production"
|
|
||||||
"""Production-only fixtures."""
|
|
||||||
|
|
||||||
DEVELOPMENT = "development"
|
|
||||||
"""Development fixtures."""
|
|
||||||
|
|
||||||
TESTING = "testing"
|
|
||||||
"""Test fixtures."""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Fixture:
|
|
||||||
"""A fixture definition with metadata."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
func: Callable[[], Sequence[DeclarativeBase]]
|
|
||||||
depends_on: list[str] = field(default_factory=list)
|
|
||||||
contexts: list[str] = field(default_factory=lambda: [Context.BASE])
|
|
||||||
|
|
||||||
|
|
||||||
class FixtureRegistry:
|
|
||||||
"""Registry for managing fixtures with dependencies.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
from fastapi_toolsets.fixtures import FixtureRegistry, Context
|
|
||||||
|
|
||||||
fixtures = FixtureRegistry()
|
|
||||||
|
|
||||||
@fixtures.register
|
|
||||||
def roles():
|
|
||||||
return [
|
|
||||||
Role(id=1, name="admin"),
|
|
||||||
Role(id=2, name="user"),
|
|
||||||
]
|
|
||||||
|
|
||||||
@fixtures.register(depends_on=["roles"])
|
|
||||||
def users():
|
|
||||||
return [
|
|
||||||
User(id=1, username="admin", role_id=1),
|
|
||||||
]
|
|
||||||
|
|
||||||
@fixtures.register(depends_on=["users"], contexts=[Context.TESTING])
|
|
||||||
def test_data():
|
|
||||||
return [
|
|
||||||
Post(id=1, title="Test", user_id=1),
|
|
||||||
]
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._fixtures: dict[str, Fixture] = {}
|
|
||||||
|
|
||||||
def register(
|
|
||||||
self,
|
|
||||||
func: Callable[[], Sequence[DeclarativeBase]] | None = None,
|
|
||||||
*,
|
|
||||||
name: str | None = None,
|
|
||||||
depends_on: list[str] | None = None,
|
|
||||||
contexts: list[str | Context] | None = None,
|
|
||||||
) -> Callable[..., Any]:
|
|
||||||
"""Register a fixture function.
|
|
||||||
|
|
||||||
Can be used as a decorator with or without arguments.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
func: Fixture function returning list of model instances
|
|
||||||
name: Fixture name (defaults to function name)
|
|
||||||
depends_on: List of fixture names this depends on
|
|
||||||
contexts: List of contexts this fixture belongs to
|
|
||||||
|
|
||||||
Example:
|
|
||||||
@fixtures.register
|
|
||||||
def roles():
|
|
||||||
return [Role(id=1, name="admin")]
|
|
||||||
|
|
||||||
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
|
|
||||||
def test_users():
|
|
||||||
return [User(id=1, username="test", role_id=1)]
|
|
||||||
"""
|
|
||||||
|
|
||||||
def decorator(
|
|
||||||
fn: Callable[[], Sequence[DeclarativeBase]],
|
|
||||||
) -> Callable[[], Sequence[DeclarativeBase]]:
|
|
||||||
fixture_name = name or cast(Any, fn).__name__
|
|
||||||
fixture_contexts = [
|
|
||||||
c.value if isinstance(c, Context) else c
|
|
||||||
for c in (contexts or [Context.BASE])
|
|
||||||
]
|
|
||||||
|
|
||||||
self._fixtures[fixture_name] = Fixture(
|
|
||||||
name=fixture_name,
|
|
||||||
func=fn,
|
|
||||||
depends_on=depends_on or [],
|
|
||||||
contexts=fixture_contexts,
|
|
||||||
)
|
|
||||||
return fn
|
|
||||||
|
|
||||||
if func is not None:
|
|
||||||
return decorator(func)
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
def get(self, name: str) -> Fixture:
|
|
||||||
"""Get a fixture by name."""
|
|
||||||
if name not in self._fixtures:
|
|
||||||
raise KeyError(f"Fixture '{name}' not found")
|
|
||||||
return self._fixtures[name]
|
|
||||||
|
|
||||||
def get_all(self) -> list[Fixture]:
|
|
||||||
"""Get all registered fixtures."""
|
|
||||||
return list(self._fixtures.values())
|
|
||||||
|
|
||||||
def get_by_context(self, *contexts: str | Context) -> list[Fixture]:
|
|
||||||
"""Get fixtures for specific contexts."""
|
|
||||||
context_values = {c.value if isinstance(c, Context) else c for c in contexts}
|
|
||||||
return [f for f in self._fixtures.values() if set(f.contexts) & context_values]
|
|
||||||
|
|
||||||
def resolve_dependencies(self, *names: str) -> list[str]:
|
|
||||||
"""Resolve fixture dependencies in topological order.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
*names: Fixture names to resolve
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of fixture names in load order (dependencies first)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
KeyError: If a fixture is not found
|
|
||||||
ValueError: If circular dependency detected
|
|
||||||
"""
|
|
||||||
resolved: list[str] = []
|
|
||||||
seen: set[str] = set()
|
|
||||||
visiting: set[str] = set()
|
|
||||||
|
|
||||||
def visit(name: str) -> None:
|
|
||||||
if name in resolved:
|
|
||||||
return
|
|
||||||
if name in visiting:
|
|
||||||
raise ValueError(f"Circular dependency detected: {name}")
|
|
||||||
|
|
||||||
visiting.add(name)
|
|
||||||
fixture = self.get(name)
|
|
||||||
|
|
||||||
for dep in fixture.depends_on:
|
|
||||||
visit(dep)
|
|
||||||
|
|
||||||
visiting.remove(name)
|
|
||||||
resolved.append(name)
|
|
||||||
seen.add(name)
|
|
||||||
|
|
||||||
for name in names:
|
|
||||||
visit(name)
|
|
||||||
|
|
||||||
return resolved
|
|
||||||
|
|
||||||
def resolve_context_dependencies(self, *contexts: str | Context) -> list[str]:
|
|
||||||
"""Resolve all fixtures for contexts with dependencies.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
*contexts: Contexts to load
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of fixture names in load order
|
|
||||||
"""
|
|
||||||
context_fixtures = self.get_by_context(*contexts)
|
|
||||||
names = [f.name for f in context_fixtures]
|
|
||||||
|
|
||||||
all_deps: set[str] = set()
|
|
||||||
for name in names:
|
|
||||||
deps = self.resolve_dependencies(name)
|
|
||||||
all_deps.update(deps)
|
|
||||||
|
|
||||||
return self.resolve_dependencies(*all_deps)
|
|
||||||
|
|
||||||
|
|
||||||
async def load_fixtures(
|
|
||||||
session: AsyncSession,
|
|
||||||
registry: FixtureRegistry,
|
|
||||||
*names: str,
|
|
||||||
strategy: LoadStrategy = LoadStrategy.MERGE,
|
|
||||||
) -> dict[str, list[DeclarativeBase]]:
|
|
||||||
"""Load specific fixtures by name with dependencies.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: Database session
|
|
||||||
registry: Fixture registry
|
|
||||||
*names: Fixture names to load (dependencies auto-resolved)
|
|
||||||
strategy: How to handle existing records
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict mapping fixture names to loaded instances
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# Loads 'roles' first (dependency), then 'users'
|
|
||||||
result = await load_fixtures(session, fixtures, "users")
|
|
||||||
print(result["users"]) # [User(...), ...]
|
|
||||||
"""
|
|
||||||
ordered = registry.resolve_dependencies(*names)
|
|
||||||
return await _load_ordered(session, registry, ordered, strategy)
|
|
||||||
|
|
||||||
|
|
||||||
async def load_fixtures_by_context(
|
|
||||||
session: AsyncSession,
|
|
||||||
registry: FixtureRegistry,
|
|
||||||
*contexts: str | Context,
|
|
||||||
strategy: LoadStrategy = LoadStrategy.MERGE,
|
|
||||||
) -> dict[str, list[DeclarativeBase]]:
|
|
||||||
"""Load all fixtures for specific contexts.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: Database session
|
|
||||||
registry: Fixture registry
|
|
||||||
*contexts: Contexts to load (e.g., Context.BASE, Context.TESTING)
|
|
||||||
strategy: How to handle existing records
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict mapping fixture names to loaded instances
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# Load base + testing fixtures
|
|
||||||
await load_fixtures_by_context(
|
|
||||||
session, fixtures,
|
|
||||||
Context.BASE, Context.TESTING
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
ordered = registry.resolve_context_dependencies(*contexts)
|
|
||||||
return await _load_ordered(session, registry, ordered, strategy)
|
|
||||||
|
|
||||||
|
|
||||||
async def _load_ordered(
|
|
||||||
session: AsyncSession,
|
|
||||||
registry: FixtureRegistry,
|
|
||||||
ordered_names: list[str],
|
|
||||||
strategy: LoadStrategy,
|
|
||||||
) -> dict[str, list[DeclarativeBase]]:
|
|
||||||
"""Load fixtures in order."""
|
|
||||||
results: dict[str, list[DeclarativeBase]] = {}
|
|
||||||
|
|
||||||
for name in ordered_names:
|
|
||||||
fixture = registry.get(name)
|
|
||||||
instances = list(fixture.func())
|
|
||||||
|
|
||||||
if not instances:
|
|
||||||
results[name] = []
|
|
||||||
continue
|
|
||||||
|
|
||||||
model_name = type(instances[0]).__name__
|
|
||||||
loaded: list[DeclarativeBase] = []
|
|
||||||
|
|
||||||
async with get_transaction(session):
|
|
||||||
for instance in instances:
|
|
||||||
if strategy == LoadStrategy.INSERT:
|
|
||||||
session.add(instance)
|
|
||||||
loaded.append(instance)
|
|
||||||
|
|
||||||
elif strategy == LoadStrategy.MERGE:
|
|
||||||
merged = await session.merge(instance)
|
|
||||||
loaded.append(merged)
|
|
||||||
|
|
||||||
elif strategy == LoadStrategy.SKIP_EXISTING:
|
|
||||||
pk = _get_primary_key(instance)
|
|
||||||
if pk is not None:
|
|
||||||
existing = await session.get(type(instance), pk)
|
|
||||||
if existing is None:
|
|
||||||
session.add(instance)
|
|
||||||
loaded.append(instance)
|
|
||||||
else:
|
|
||||||
session.add(instance)
|
|
||||||
loaded.append(instance)
|
|
||||||
|
|
||||||
results[name] = loaded
|
|
||||||
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
|
||||||
"""Get the primary key value of a model instance."""
|
|
||||||
mapper = instance.__class__.__mapper__
|
|
||||||
pk_cols = mapper.primary_key
|
|
||||||
|
|
||||||
if len(pk_cols) == 1:
|
|
||||||
return getattr(instance, pk_cols[0].name, None)
|
|
||||||
|
|
||||||
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
|
|
||||||
if all(v is not None for v in pk_values):
|
|
||||||
return pk_values
|
|
||||||
return None
|
|
||||||
311
src/fastapi_toolsets/fixtures/registry.py
Normal file
311
src/fastapi_toolsets/fixtures/registry.py
Normal file
@@ -0,0 +1,311 @@
|
|||||||
|
"""Fixture system with dependency management and context support."""
|
||||||
|
|
||||||
|
from collections.abc import Callable, Sequence
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from ..logger import get_logger
|
||||||
|
from .enum import Context
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_contexts(
|
||||||
|
contexts: list[str | Enum] | tuple[str | Enum, ...],
|
||||||
|
) -> list[str]:
|
||||||
|
"""Convert a sequence of any Enum subclass and/or plain strings to a list of strings."""
|
||||||
|
return [c.value if isinstance(c, Enum) else c for c in contexts]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Fixture:
|
||||||
|
"""A fixture definition with metadata."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
func: Callable[[], Sequence[DeclarativeBase]]
|
||||||
|
depends_on: list[str] = field(default_factory=list)
|
||||||
|
contexts: list[str] = field(default_factory=lambda: [Context.BASE])
|
||||||
|
|
||||||
|
|
||||||
|
class FixtureRegistry:
|
||||||
|
"""Registry for managing fixtures with dependencies.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import FixtureRegistry, Context
|
||||||
|
|
||||||
|
fixtures = FixtureRegistry()
|
||||||
|
|
||||||
|
@fixtures.register
|
||||||
|
def roles():
|
||||||
|
return [
|
||||||
|
Role(id=1, name="admin"),
|
||||||
|
Role(id=2, name="user"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@fixtures.register(depends_on=["roles"])
|
||||||
|
def users():
|
||||||
|
return [
|
||||||
|
User(id=1, username="admin", role_id=1),
|
||||||
|
]
|
||||||
|
|
||||||
|
@fixtures.register(depends_on=["users"], contexts=[Context.TESTING])
|
||||||
|
def test_data():
|
||||||
|
return [
|
||||||
|
Post(id=1, title="Test", user_id=1),
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
Fixtures with the same name may be registered for **different** contexts.
|
||||||
|
When multiple contexts are loaded together, their instances are merged:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@fixtures.register(contexts=[Context.BASE])
|
||||||
|
def users():
|
||||||
|
return [User(id=1, username="admin")]
|
||||||
|
|
||||||
|
@fixtures.register(contexts=[Context.TESTING])
|
||||||
|
def users():
|
||||||
|
return [User(id=2, username="tester")]
|
||||||
|
# load_fixtures_by_context(..., Context.BASE, Context.TESTING)
|
||||||
|
# → loads both User(admin) and User(tester) under the "users" name
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
contexts: list[str | Enum] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._fixtures: dict[str, list[Fixture]] = {}
|
||||||
|
self._default_contexts: list[str] | None = (
|
||||||
|
_normalize_contexts(contexts) if contexts else None
|
||||||
|
)
|
||||||
|
|
||||||
|
def _validate_no_context_overlap(self, name: str, new_contexts: list[str]) -> None:
|
||||||
|
"""Raise ``ValueError`` if any existing variant for *name* overlaps."""
|
||||||
|
existing_variants = self._fixtures.get(name, [])
|
||||||
|
new_set = set(new_contexts)
|
||||||
|
for variant in existing_variants:
|
||||||
|
if set(variant.contexts) & new_set:
|
||||||
|
raise ValueError(
|
||||||
|
f"Fixture '{name}' already exists in the current registry "
|
||||||
|
f"with overlapping contexts. Use distinct context sets for "
|
||||||
|
f"each variant of the same fixture name."
|
||||||
|
)
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self,
|
||||||
|
func: Callable[[], Sequence[DeclarativeBase]] | None = None,
|
||||||
|
*,
|
||||||
|
name: str | None = None,
|
||||||
|
depends_on: list[str] | None = None,
|
||||||
|
contexts: list[str | Enum] | None = None,
|
||||||
|
) -> Callable[..., Any]:
|
||||||
|
"""Register a fixture function.
|
||||||
|
|
||||||
|
Can be used as a decorator with or without arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: Fixture function returning list of model instances
|
||||||
|
name: Fixture name (defaults to function name)
|
||||||
|
depends_on: List of fixture names this depends on
|
||||||
|
contexts: List of contexts this fixture belongs to. Both
|
||||||
|
:class:`Context` enum values and plain strings are accepted.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
@fixtures.register
|
||||||
|
def roles():
|
||||||
|
return [Role(id=1, name="admin")]
|
||||||
|
|
||||||
|
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
|
||||||
|
def test_users():
|
||||||
|
return [User(id=1, username="test", role_id=1)]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(
|
||||||
|
fn: Callable[[], Sequence[DeclarativeBase]],
|
||||||
|
) -> Callable[[], Sequence[DeclarativeBase]]:
|
||||||
|
fixture_name = name or cast(Any, fn).__name__
|
||||||
|
if contexts is not None:
|
||||||
|
fixture_contexts = _normalize_contexts(contexts)
|
||||||
|
elif self._default_contexts is not None:
|
||||||
|
fixture_contexts = self._default_contexts
|
||||||
|
else:
|
||||||
|
fixture_contexts = [Context.BASE.value]
|
||||||
|
|
||||||
|
self._validate_no_context_overlap(fixture_name, fixture_contexts)
|
||||||
|
self._fixtures.setdefault(fixture_name, []).append(
|
||||||
|
Fixture(
|
||||||
|
name=fixture_name,
|
||||||
|
func=fn,
|
||||||
|
depends_on=depends_on or [],
|
||||||
|
contexts=fixture_contexts,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return fn
|
||||||
|
|
||||||
|
if func is not None:
|
||||||
|
return decorator(func)
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def include_registry(self, registry: "FixtureRegistry") -> None:
|
||||||
|
"""Include another `FixtureRegistry` in the same current `FixtureRegistry`.
|
||||||
|
|
||||||
|
Fixtures with the same name are allowed as long as their context sets
|
||||||
|
do not overlap. Conflicting contexts raise :class:`ValueError`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registry: The `FixtureRegistry` to include
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If a fixture name already exists with overlapping contexts
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
dev_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@dev_registry.register
|
||||||
|
def dev_data():
|
||||||
|
return [...]
|
||||||
|
|
||||||
|
registry.include_registry(registry=dev_registry)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
for name, variants in registry._fixtures.items():
|
||||||
|
for fixture in variants:
|
||||||
|
self._validate_no_context_overlap(name, fixture.contexts)
|
||||||
|
self._fixtures.setdefault(name, []).append(fixture)
|
||||||
|
|
||||||
|
def get(self, name: str) -> Fixture:
|
||||||
|
"""Get a fixture by name.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If no fixture with *name* is registered.
|
||||||
|
ValueError: If the fixture has multiple context variants — use
|
||||||
|
:meth:`get_variants` in that case.
|
||||||
|
"""
|
||||||
|
if name not in self._fixtures:
|
||||||
|
raise KeyError(f"Fixture '{name}' not found")
|
||||||
|
variants = self._fixtures[name]
|
||||||
|
if len(variants) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Fixture '{name}' has {len(variants)} context variants. "
|
||||||
|
f"Use get_variants('{name}') to retrieve them."
|
||||||
|
)
|
||||||
|
return variants[0]
|
||||||
|
|
||||||
|
def get_variants(self, name: str, *contexts: str | Enum) -> list[Fixture]:
|
||||||
|
"""Return all registered variants for *name*, optionally filtered by context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Fixture name.
|
||||||
|
*contexts: If given, only return variants whose context set
|
||||||
|
intersects with these values. Both :class:`Context` enum
|
||||||
|
values and plain strings are accepted.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching :class:`Fixture` objects (may be empty when a
|
||||||
|
context filter is applied and nothing matches).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If no fixture with *name* is registered.
|
||||||
|
"""
|
||||||
|
if name not in self._fixtures:
|
||||||
|
raise KeyError(f"Fixture '{name}' not found")
|
||||||
|
variants = self._fixtures[name]
|
||||||
|
if not contexts:
|
||||||
|
return list(variants)
|
||||||
|
context_values = set(_normalize_contexts(contexts))
|
||||||
|
return [v for v in variants if set(v.contexts) & context_values]
|
||||||
|
|
||||||
|
def get_all(self) -> list[Fixture]:
|
||||||
|
"""Get all registered fixtures (all variants of all names)."""
|
||||||
|
return [f for variants in self._fixtures.values() for f in variants]
|
||||||
|
|
||||||
|
def get_by_context(self, *contexts: str | Enum) -> list[Fixture]:
|
||||||
|
"""Get fixtures for specific contexts."""
|
||||||
|
context_values = set(_normalize_contexts(contexts))
|
||||||
|
return [
|
||||||
|
f
|
||||||
|
for variants in self._fixtures.values()
|
||||||
|
for f in variants
|
||||||
|
if set(f.contexts) & context_values
|
||||||
|
]
|
||||||
|
|
||||||
|
def resolve_dependencies(self, *names: str) -> list[str]:
|
||||||
|
"""Resolve fixture dependencies in topological order.
|
||||||
|
|
||||||
|
When a fixture name has multiple context variants, the union of all
|
||||||
|
variants' ``depends_on`` lists is used.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*names: Fixture names to resolve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of fixture names in load order (dependencies first)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If a fixture is not found
|
||||||
|
ValueError: If circular dependency detected
|
||||||
|
"""
|
||||||
|
resolved: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
visiting: set[str] = set()
|
||||||
|
|
||||||
|
def visit(name: str) -> None:
|
||||||
|
if name in resolved:
|
||||||
|
return
|
||||||
|
if name in visiting:
|
||||||
|
raise ValueError(f"Circular dependency detected: {name}")
|
||||||
|
|
||||||
|
visiting.add(name)
|
||||||
|
variants = self._fixtures.get(name)
|
||||||
|
if variants is None:
|
||||||
|
raise KeyError(f"Fixture '{name}' not found")
|
||||||
|
|
||||||
|
# Union of depends_on across all variants, preserving first-seen order.
|
||||||
|
seen_deps: set[str] = set()
|
||||||
|
all_deps: list[str] = []
|
||||||
|
for variant in variants:
|
||||||
|
for dep in variant.depends_on:
|
||||||
|
if dep not in seen_deps:
|
||||||
|
all_deps.append(dep)
|
||||||
|
seen_deps.add(dep)
|
||||||
|
|
||||||
|
for dep in all_deps:
|
||||||
|
visit(dep)
|
||||||
|
|
||||||
|
visiting.remove(name)
|
||||||
|
resolved.append(name)
|
||||||
|
seen.add(name)
|
||||||
|
|
||||||
|
for name in names:
|
||||||
|
visit(name)
|
||||||
|
|
||||||
|
return resolved
|
||||||
|
|
||||||
|
def resolve_context_dependencies(self, *contexts: str | Enum) -> list[str]:
|
||||||
|
"""Resolve all fixtures for contexts with dependencies.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*contexts: Contexts to load
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of fixture names in load order
|
||||||
|
"""
|
||||||
|
context_fixtures = self.get_by_context(*contexts)
|
||||||
|
# Deduplicate names while preserving first-seen order (a name can
|
||||||
|
# appear multiple times if it has variants in different contexts).
|
||||||
|
names = list(dict.fromkeys(f.name for f in context_fixtures))
|
||||||
|
|
||||||
|
all_deps: set[str] = set()
|
||||||
|
for name in names:
|
||||||
|
deps = self.resolve_dependencies(name)
|
||||||
|
all_deps.update(deps)
|
||||||
|
|
||||||
|
return self.resolve_dependencies(*all_deps)
|
||||||
298
src/fastapi_toolsets/fixtures/utils.py
Normal file
298
src/fastapi_toolsets/fixtures/utils.py
Normal file
@@ -0,0 +1,298 @@
|
|||||||
|
"""Fixture loading utilities for database seeding."""
|
||||||
|
|
||||||
|
from collections.abc import Callable, Sequence
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import inspect as sa_inspect
|
||||||
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from ..db import get_transaction
|
||||||
|
from ..logger import get_logger
|
||||||
|
from ..types import ModelType
|
||||||
|
from .enum import LoadStrategy
|
||||||
|
from .registry import FixtureRegistry, _normalize_contexts
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _instance_to_dict(instance: DeclarativeBase) -> dict[str, Any]:
|
||||||
|
"""Extract column values from a model instance, skipping unset server-default columns."""
|
||||||
|
state = sa_inspect(instance)
|
||||||
|
state_dict = state.dict
|
||||||
|
result: dict[str, Any] = {}
|
||||||
|
for prop in state.mapper.column_attrs:
|
||||||
|
if prop.key not in state_dict:
|
||||||
|
continue
|
||||||
|
val = state_dict[prop.key]
|
||||||
|
if val is None:
|
||||||
|
col = prop.columns[0]
|
||||||
|
|
||||||
|
if (
|
||||||
|
col.server_default is not None
|
||||||
|
or (col.default is not None and col.default.is_callable)
|
||||||
|
or col.autoincrement is True
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
result[prop.key] = val
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _group_by_type(
|
||||||
|
instances: list[DeclarativeBase],
|
||||||
|
) -> list[tuple[type[DeclarativeBase], list[DeclarativeBase]]]:
|
||||||
|
"""Group instances by their concrete model class, preserving insertion order."""
|
||||||
|
groups: dict[type[DeclarativeBase], list[DeclarativeBase]] = {}
|
||||||
|
for instance in instances:
|
||||||
|
groups.setdefault(type(instance), []).append(instance)
|
||||||
|
return list(groups.items())
|
||||||
|
|
||||||
|
|
||||||
|
def _group_by_column_set(
|
||||||
|
dicts: list[dict[str, Any]],
|
||||||
|
instances: list[DeclarativeBase],
|
||||||
|
) -> list[tuple[list[dict[str, Any]], list[DeclarativeBase]]]:
|
||||||
|
"""Group (dict, instance) pairs by their dict key sets."""
|
||||||
|
groups: dict[
|
||||||
|
frozenset[str], tuple[list[dict[str, Any]], list[DeclarativeBase]]
|
||||||
|
] = {}
|
||||||
|
for d, inst in zip(dicts, instances):
|
||||||
|
key = frozenset(d)
|
||||||
|
if key not in groups:
|
||||||
|
groups[key] = ([], [])
|
||||||
|
groups[key][0].append(d)
|
||||||
|
groups[key][1].append(inst)
|
||||||
|
return list(groups.values())
|
||||||
|
|
||||||
|
|
||||||
|
async def _batch_insert(
|
||||||
|
session: AsyncSession,
|
||||||
|
model_cls: type[DeclarativeBase],
|
||||||
|
instances: list[DeclarativeBase],
|
||||||
|
) -> None:
|
||||||
|
"""INSERT all instances — raises on conflict (no duplicate handling)."""
|
||||||
|
dicts = [_instance_to_dict(i) for i in instances]
|
||||||
|
for group_dicts, _ in _group_by_column_set(dicts, instances):
|
||||||
|
await session.execute(pg_insert(model_cls).values(group_dicts))
|
||||||
|
|
||||||
|
|
||||||
|
async def _batch_merge(
|
||||||
|
session: AsyncSession,
|
||||||
|
model_cls: type[DeclarativeBase],
|
||||||
|
instances: list[DeclarativeBase],
|
||||||
|
) -> None:
|
||||||
|
"""UPSERT: insert new rows, update existing ones with the provided values."""
|
||||||
|
mapper = model_cls.__mapper__
|
||||||
|
pk_names = [col.name for col in mapper.primary_key]
|
||||||
|
pk_names_set = set(pk_names)
|
||||||
|
non_pk_cols = [
|
||||||
|
prop.key
|
||||||
|
for prop in mapper.column_attrs
|
||||||
|
if not any(col.name in pk_names_set for col in prop.columns)
|
||||||
|
]
|
||||||
|
|
||||||
|
dicts = [_instance_to_dict(i) for i in instances]
|
||||||
|
for group_dicts, _ in _group_by_column_set(dicts, instances):
|
||||||
|
stmt = pg_insert(model_cls).values(group_dicts)
|
||||||
|
|
||||||
|
inserted_keys = set(group_dicts[0])
|
||||||
|
update_cols = [col for col in non_pk_cols if col in inserted_keys]
|
||||||
|
|
||||||
|
if update_cols:
|
||||||
|
stmt = stmt.on_conflict_do_update(
|
||||||
|
index_elements=pk_names,
|
||||||
|
set_={col: stmt.excluded[col] for col in update_cols},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
stmt = stmt.on_conflict_do_nothing(index_elements=pk_names)
|
||||||
|
|
||||||
|
await session.execute(stmt)
|
||||||
|
|
||||||
|
|
||||||
|
async def _batch_skip_existing(
|
||||||
|
session: AsyncSession,
|
||||||
|
model_cls: type[DeclarativeBase],
|
||||||
|
instances: list[DeclarativeBase],
|
||||||
|
) -> list[DeclarativeBase]:
|
||||||
|
"""INSERT only rows that do not already exist; return the inserted ones."""
|
||||||
|
mapper = model_cls.__mapper__
|
||||||
|
pk_names = [col.name for col in mapper.primary_key]
|
||||||
|
|
||||||
|
no_pk: list[DeclarativeBase] = []
|
||||||
|
with_pk_pairs: list[tuple[DeclarativeBase, Any]] = []
|
||||||
|
for inst in instances:
|
||||||
|
pk = _get_primary_key(inst)
|
||||||
|
if pk is None:
|
||||||
|
no_pk.append(inst)
|
||||||
|
else:
|
||||||
|
with_pk_pairs.append((inst, pk))
|
||||||
|
|
||||||
|
loaded: list[DeclarativeBase] = list(no_pk)
|
||||||
|
if no_pk:
|
||||||
|
no_pk_dicts = [_instance_to_dict(i) for i in no_pk]
|
||||||
|
for group_dicts, _ in _group_by_column_set(no_pk_dicts, no_pk):
|
||||||
|
await session.execute(pg_insert(model_cls).values(group_dicts))
|
||||||
|
|
||||||
|
if with_pk_pairs:
|
||||||
|
with_pk = [i for i, _ in with_pk_pairs]
|
||||||
|
with_pk_dicts = [_instance_to_dict(i) for i in with_pk]
|
||||||
|
for group_dicts, group_insts in _group_by_column_set(with_pk_dicts, with_pk):
|
||||||
|
stmt = (
|
||||||
|
pg_insert(model_cls)
|
||||||
|
.values(group_dicts)
|
||||||
|
.on_conflict_do_nothing(index_elements=pk_names)
|
||||||
|
)
|
||||||
|
result = await session.execute(stmt.returning(*mapper.primary_key))
|
||||||
|
inserted_pks = {
|
||||||
|
row[0] if len(pk_names) == 1 else tuple(row) for row in result
|
||||||
|
}
|
||||||
|
loaded.extend(
|
||||||
|
inst
|
||||||
|
for inst, pk in zip(
|
||||||
|
group_insts, [_get_primary_key(i) for i in group_insts]
|
||||||
|
)
|
||||||
|
if pk in inserted_pks
|
||||||
|
)
|
||||||
|
|
||||||
|
return loaded
|
||||||
|
|
||||||
|
|
||||||
|
async def _load_ordered(
|
||||||
|
session: AsyncSession,
|
||||||
|
registry: FixtureRegistry,
|
||||||
|
ordered_names: list[str],
|
||||||
|
strategy: LoadStrategy,
|
||||||
|
contexts: tuple[str, ...] | None = None,
|
||||||
|
) -> dict[str, list[DeclarativeBase]]:
|
||||||
|
"""Load fixtures in order using batch Core INSERT statements."""
|
||||||
|
results: dict[str, list[DeclarativeBase]] = {}
|
||||||
|
|
||||||
|
for name in ordered_names:
|
||||||
|
variants = (
|
||||||
|
registry.get_variants(name, *contexts)
|
||||||
|
if contexts is not None
|
||||||
|
else registry.get_variants(name)
|
||||||
|
)
|
||||||
|
|
||||||
|
if contexts is not None and not variants:
|
||||||
|
variants = registry.get_variants(name)
|
||||||
|
|
||||||
|
if not variants:
|
||||||
|
results[name] = []
|
||||||
|
continue
|
||||||
|
|
||||||
|
instances = [inst for v in variants for inst in v.func()]
|
||||||
|
|
||||||
|
if not instances:
|
||||||
|
results[name] = []
|
||||||
|
continue
|
||||||
|
|
||||||
|
model_name = type(instances[0]).__name__
|
||||||
|
loaded: list[DeclarativeBase] = []
|
||||||
|
|
||||||
|
async with get_transaction(session):
|
||||||
|
for model_cls, group in _group_by_type(instances):
|
||||||
|
match strategy:
|
||||||
|
case LoadStrategy.INSERT:
|
||||||
|
await _batch_insert(session, model_cls, group)
|
||||||
|
loaded.extend(group)
|
||||||
|
case LoadStrategy.MERGE:
|
||||||
|
await _batch_merge(session, model_cls, group)
|
||||||
|
loaded.extend(group)
|
||||||
|
case LoadStrategy.SKIP_EXISTING:
|
||||||
|
inserted = await _batch_skip_existing(session, model_cls, group)
|
||||||
|
loaded.extend(inserted)
|
||||||
|
|
||||||
|
results[name] = loaded
|
||||||
|
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
||||||
|
"""Get the primary key value of a model instance."""
|
||||||
|
mapper = instance.__class__.__mapper__
|
||||||
|
pk_cols = mapper.primary_key
|
||||||
|
|
||||||
|
if len(pk_cols) == 1:
|
||||||
|
return getattr(instance, pk_cols[0].name, None)
|
||||||
|
|
||||||
|
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
|
||||||
|
if all(v is not None for v in pk_values):
|
||||||
|
return pk_values
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_obj_by_attr(
|
||||||
|
fixtures: Callable[[], Sequence[ModelType]], attr_name: str, value: Any
|
||||||
|
) -> ModelType:
|
||||||
|
"""Get a SQLAlchemy model instance by matching an attribute value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fixtures: A fixture function registered via ``@registry.register``
|
||||||
|
that returns a sequence of SQLAlchemy model instances.
|
||||||
|
attr_name: Name of the attribute to match against.
|
||||||
|
value: Value to match.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The first model instance where the attribute matches the given value.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
StopIteration: If no matching object is found in the fixture group.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return next(obj for obj in fixtures() if getattr(obj, attr_name) == value)
|
||||||
|
except StopIteration:
|
||||||
|
raise StopIteration(
|
||||||
|
f"No object with {attr_name}={value} found in fixture '{getattr(fixtures, '__name__', repr(fixtures))}'"
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
async def load_fixtures(
|
||||||
|
session: AsyncSession,
|
||||||
|
registry: FixtureRegistry,
|
||||||
|
*names: str,
|
||||||
|
strategy: LoadStrategy = LoadStrategy.MERGE,
|
||||||
|
) -> dict[str, list[DeclarativeBase]]:
|
||||||
|
"""Load specific fixtures by name with dependencies.
|
||||||
|
|
||||||
|
All context variants of each requested fixture are loaded and merged.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
registry: Fixture registry
|
||||||
|
*names: Fixture names to load (dependencies auto-resolved)
|
||||||
|
strategy: How to handle existing records
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping fixture names to loaded instances
|
||||||
|
"""
|
||||||
|
ordered = registry.resolve_dependencies(*names)
|
||||||
|
return await _load_ordered(session, registry, ordered, strategy)
|
||||||
|
|
||||||
|
|
||||||
|
async def load_fixtures_by_context(
|
||||||
|
session: AsyncSession,
|
||||||
|
registry: FixtureRegistry,
|
||||||
|
*contexts: str | Enum,
|
||||||
|
strategy: LoadStrategy = LoadStrategy.MERGE,
|
||||||
|
) -> dict[str, list[DeclarativeBase]]:
|
||||||
|
"""Load all fixtures for specific contexts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
registry: Fixture registry
|
||||||
|
*contexts: Contexts to load (e.g., ``Context.BASE``, ``Context.TESTING``,
|
||||||
|
or plain strings for custom contexts)
|
||||||
|
strategy: How to handle existing records
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping fixture names to loaded instances
|
||||||
|
"""
|
||||||
|
context_strings = tuple(_normalize_contexts(contexts))
|
||||||
|
ordered = registry.resolve_context_dependencies(*contexts)
|
||||||
|
return await _load_ordered(
|
||||||
|
session, registry, ordered, strategy, contexts=context_strings
|
||||||
|
)
|
||||||
98
src/fastapi_toolsets/logger.py
Normal file
98
src/fastapi_toolsets/logger.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""Logging configuration for FastAPI applications and CLI tools."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
__all__ = ["LogLevel", "configure_logging", "get_logger"]
|
||||||
|
|
||||||
|
DEFAULT_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
UVICORN_LOGGERS = ("uvicorn", "uvicorn.access", "uvicorn.error")
|
||||||
|
|
||||||
|
LogLevel = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
|
||||||
|
|
||||||
|
|
||||||
|
def configure_logging(
|
||||||
|
level: LogLevel | int = "INFO",
|
||||||
|
fmt: str = DEFAULT_FORMAT,
|
||||||
|
logger_name: str | None = None,
|
||||||
|
) -> logging.Logger:
|
||||||
|
"""Configure logging with a stdout handler and consistent format.
|
||||||
|
|
||||||
|
Sets up a :class:`~logging.StreamHandler` writing to stdout with the
|
||||||
|
given format and level. Also configures the uvicorn loggers so that
|
||||||
|
FastAPI access logs use the same format.
|
||||||
|
|
||||||
|
Calling this function multiple times is safe -- existing handlers are
|
||||||
|
replaced rather than duplicated.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: Log level (e.g. ``"DEBUG"``, ``"INFO"``, or ``logging.DEBUG``).
|
||||||
|
fmt: Log format string. Defaults to
|
||||||
|
``"%(asctime)s - %(name)s - %(levelname)s - %(message)s"``.
|
||||||
|
logger_name: Logger name to configure. ``None`` (the default)
|
||||||
|
configures the root logger so all loggers inherit the settings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The configured Logger instance.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import configure_logging
|
||||||
|
|
||||||
|
logger = configure_logging("DEBUG")
|
||||||
|
logger.info("Application started")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
formatter = logging.Formatter(fmt)
|
||||||
|
|
||||||
|
handler = logging.StreamHandler(sys.stdout)
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
logger = logging.getLogger(logger_name)
|
||||||
|
logger.handlers.clear()
|
||||||
|
logger.addHandler(handler)
|
||||||
|
logger.setLevel(level)
|
||||||
|
|
||||||
|
for name in UVICORN_LOGGERS:
|
||||||
|
uv_logger = logging.getLogger(name)
|
||||||
|
uv_logger.handlers.clear()
|
||||||
|
uv_logger.addHandler(handler)
|
||||||
|
uv_logger.setLevel(level)
|
||||||
|
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
_SENTINEL = object()
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(name: str | None = _SENTINEL) -> logging.Logger: # type: ignore[assignment] # ty:ignore[invalid-parameter-default]
|
||||||
|
"""Return a logger with the given *name*.
|
||||||
|
|
||||||
|
A thin convenience wrapper around :func:`logging.getLogger` that keeps
|
||||||
|
logging imports consistent across the codebase.
|
||||||
|
|
||||||
|
When called without arguments, the caller's ``__name__`` is used
|
||||||
|
automatically, so ``get_logger()`` in a module is equivalent to
|
||||||
|
``logging.getLogger(__name__)``. Pass ``None`` explicitly to get the
|
||||||
|
root logger.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Logger name. Defaults to the caller's ``__name__``.
|
||||||
|
Pass ``None`` to get the root logger.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Logger instance.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger() # uses caller's __name__
|
||||||
|
logger = get_logger("myapp") # explicit name
|
||||||
|
logger = get_logger(None) # root logger
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if name is _SENTINEL:
|
||||||
|
name = sys._getframe(1).f_globals.get("__name__")
|
||||||
|
return logging.getLogger(name)
|
||||||
21
src/fastapi_toolsets/metrics/__init__.py
Normal file
21
src/fastapi_toolsets/metrics/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""Prometheus metrics integration for FastAPI applications."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .registry import Metric, MetricsRegistry
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .handler import init_metrics
|
||||||
|
except ImportError:
|
||||||
|
|
||||||
|
def init_metrics(*_args: Any, **_kwargs: Any) -> None:
|
||||||
|
from .._imports import require_extra
|
||||||
|
|
||||||
|
require_extra(package="prometheus_client", extra="metrics")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Metric",
|
||||||
|
"MetricsRegistry",
|
||||||
|
"init_metrics",
|
||||||
|
]
|
||||||
81
src/fastapi_toolsets/metrics/handler.py
Normal file
81
src/fastapi_toolsets/metrics/handler.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
"""Prometheus metrics endpoint for FastAPI applications."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import os
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.responses import Response
|
||||||
|
from prometheus_client import (
|
||||||
|
CONTENT_TYPE_LATEST,
|
||||||
|
CollectorRegistry,
|
||||||
|
generate_latest,
|
||||||
|
multiprocess,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..logger import get_logger
|
||||||
|
from .registry import MetricsRegistry
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _is_multiprocess() -> bool:
|
||||||
|
"""Check if prometheus multi-process mode is enabled."""
|
||||||
|
return "PROMETHEUS_MULTIPROC_DIR" in os.environ
|
||||||
|
|
||||||
|
|
||||||
|
def init_metrics(
|
||||||
|
app: FastAPI,
|
||||||
|
registry: MetricsRegistry,
|
||||||
|
*,
|
||||||
|
path: str = "/metrics",
|
||||||
|
) -> FastAPI:
|
||||||
|
"""Register a Prometheus ``/metrics`` endpoint on a FastAPI app.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI application instance.
|
||||||
|
registry: A :class:`MetricsRegistry` containing providers and collectors.
|
||||||
|
path: URL path for the metrics endpoint (default ``/metrics``).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The same FastAPI instance (for chaining).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.metrics import MetricsRegistry, init_metrics
|
||||||
|
|
||||||
|
metrics = MetricsRegistry()
|
||||||
|
app = FastAPI()
|
||||||
|
init_metrics(app, registry=metrics)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
for provider in registry.get_providers():
|
||||||
|
logger.debug("Initialising metric provider '%s'", provider.name)
|
||||||
|
registry._instances[provider.name] = provider.func()
|
||||||
|
|
||||||
|
# Partition collectors and cache env check at startup — both are stable for the app lifetime.
|
||||||
|
async_collectors = [
|
||||||
|
c for c in registry.get_collectors() if inspect.iscoroutinefunction(c.func)
|
||||||
|
]
|
||||||
|
sync_collectors = [
|
||||||
|
c for c in registry.get_collectors() if not inspect.iscoroutinefunction(c.func)
|
||||||
|
]
|
||||||
|
multiprocess_mode = _is_multiprocess()
|
||||||
|
|
||||||
|
@app.get(path, include_in_schema=False)
|
||||||
|
async def metrics_endpoint() -> Response:
|
||||||
|
for collector in sync_collectors:
|
||||||
|
collector.func()
|
||||||
|
for collector in async_collectors:
|
||||||
|
await collector.func()
|
||||||
|
|
||||||
|
if multiprocess_mode:
|
||||||
|
prom_registry = CollectorRegistry()
|
||||||
|
multiprocess.MultiProcessCollector(prom_registry)
|
||||||
|
output = generate_latest(prom_registry)
|
||||||
|
else:
|
||||||
|
output = generate_latest()
|
||||||
|
|
||||||
|
return Response(content=output, media_type=CONTENT_TYPE_LATEST)
|
||||||
|
|
||||||
|
return app
|
||||||
104
src/fastapi_toolsets/metrics/registry.py
Normal file
104
src/fastapi_toolsets/metrics/registry.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
"""Metrics registry with decorator-based registration."""
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from ..logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Metric:
|
||||||
|
"""A metric definition with metadata."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
func: Callable[..., Any]
|
||||||
|
collect: bool = field(default=False)
|
||||||
|
|
||||||
|
|
||||||
|
class MetricsRegistry:
|
||||||
|
"""Registry for managing Prometheus metric providers and collectors."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._metrics: dict[str, Metric] = {}
|
||||||
|
self._instances: dict[str, Any] = {}
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self,
|
||||||
|
func: Callable[..., Any] | None = None,
|
||||||
|
*,
|
||||||
|
name: str | None = None,
|
||||||
|
collect: bool = False,
|
||||||
|
) -> Callable[..., Any]:
|
||||||
|
"""Register a metric provider or collector function.
|
||||||
|
|
||||||
|
Can be used as a decorator with or without arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The metric function to register.
|
||||||
|
name: Metric name (defaults to function name).
|
||||||
|
collect: If ``True``, the function is called on every scrape.
|
||||||
|
If ``False`` (default), called once at init time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
|
metric_name = name or cast(Any, fn).__name__
|
||||||
|
self._metrics[metric_name] = Metric(
|
||||||
|
name=metric_name,
|
||||||
|
func=fn,
|
||||||
|
collect=collect,
|
||||||
|
)
|
||||||
|
return fn
|
||||||
|
|
||||||
|
if func is not None:
|
||||||
|
return decorator(func)
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def get(self, name: str) -> Any:
|
||||||
|
"""Return the metric instance created by a provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The metric name (defaults to the provider function name).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the metric name is unknown or ``init_metrics`` has not
|
||||||
|
been called yet.
|
||||||
|
"""
|
||||||
|
if name not in self._instances:
|
||||||
|
if name in self._metrics:
|
||||||
|
raise KeyError(
|
||||||
|
f"Metric '{name}' exists but has not been initialized yet. "
|
||||||
|
"Ensure init_metrics() has been called before accessing metric instances."
|
||||||
|
)
|
||||||
|
raise KeyError(f"Unknown metric '{name}'.")
|
||||||
|
return self._instances[name]
|
||||||
|
|
||||||
|
def include_registry(self, registry: "MetricsRegistry") -> None:
|
||||||
|
"""Include another :class:`MetricsRegistry` into this one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registry: The registry to merge in.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If a metric name already exists in the current registry.
|
||||||
|
"""
|
||||||
|
for metric_name, definition in registry._metrics.items():
|
||||||
|
if metric_name in self._metrics:
|
||||||
|
raise ValueError(
|
||||||
|
f"Metric '{metric_name}' already exists in the current registry"
|
||||||
|
)
|
||||||
|
self._metrics[metric_name] = definition
|
||||||
|
|
||||||
|
def get_all(self) -> list[Metric]:
|
||||||
|
"""Get all registered metric definitions."""
|
||||||
|
return list(self._metrics.values())
|
||||||
|
|
||||||
|
def get_providers(self) -> list[Metric]:
|
||||||
|
"""Get metric providers (called once at init)."""
|
||||||
|
return [m for m in self._metrics.values() if not m.collect]
|
||||||
|
|
||||||
|
def get_collectors(self) -> list[Metric]:
|
||||||
|
"""Get collectors (called on each scrape)."""
|
||||||
|
return [m for m in self._metrics.values() if m.collect]
|
||||||
21
src/fastapi_toolsets/models/__init__.py
Normal file
21
src/fastapi_toolsets/models/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""SQLAlchemy model mixins for common column patterns."""
|
||||||
|
|
||||||
|
from .columns import (
|
||||||
|
CreatedAtMixin,
|
||||||
|
TimestampMixin,
|
||||||
|
UUIDMixin,
|
||||||
|
UUIDv7Mixin,
|
||||||
|
UpdatedAtMixin,
|
||||||
|
)
|
||||||
|
from .watched import EventSession, ModelEvent, listens_for
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"EventSession",
|
||||||
|
"ModelEvent",
|
||||||
|
"UUIDMixin",
|
||||||
|
"UUIDv7Mixin",
|
||||||
|
"CreatedAtMixin",
|
||||||
|
"UpdatedAtMixin",
|
||||||
|
"TimestampMixin",
|
||||||
|
"listens_for",
|
||||||
|
]
|
||||||
50
src/fastapi_toolsets/models/columns.py
Normal file
50
src/fastapi_toolsets/models/columns.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""SQLAlchemy column mixins for common column patterns."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, Uuid, text
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
|
||||||
|
class UUIDMixin:
|
||||||
|
"""Mixin that adds a UUID primary key auto-generated by the database."""
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
Uuid,
|
||||||
|
primary_key=True,
|
||||||
|
server_default=text("gen_random_uuid()"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UUIDv7Mixin:
|
||||||
|
"""Mixin that adds a UUIDv7 primary key auto-generated by the database."""
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
Uuid,
|
||||||
|
primary_key=True,
|
||||||
|
server_default=text("uuidv7()"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CreatedAtMixin:
|
||||||
|
"""Mixin that adds a ``created_at`` timestamp column."""
|
||||||
|
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
server_default=text("clock_timestamp()"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UpdatedAtMixin:
|
||||||
|
"""Mixin that adds an ``updated_at`` timestamp column."""
|
||||||
|
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
server_default=text("clock_timestamp()"),
|
||||||
|
onupdate=text("clock_timestamp()"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TimestampMixin(CreatedAtMixin, UpdatedAtMixin):
|
||||||
|
"""Mixin that combines ``created_at`` and ``updated_at`` timestamp columns."""
|
||||||
290
src/fastapi_toolsets/models/watched.py
Normal file
290
src/fastapi_toolsets/models/watched.py
Normal file
@@ -0,0 +1,290 @@
|
|||||||
|
"""Field-change monitoring via SQLAlchemy session events."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from collections.abc import Callable
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import event
|
||||||
|
from sqlalchemy import inspect as sa_inspect
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm.attributes import set_committed_value as _sa_set_committed_value
|
||||||
|
|
||||||
|
from ..logger import get_logger
|
||||||
|
|
||||||
|
_logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class ModelEvent(str, Enum):
|
||||||
|
"""Event types dispatched by :class:`EventSession`."""
|
||||||
|
|
||||||
|
CREATE = "create"
|
||||||
|
DELETE = "delete"
|
||||||
|
UPDATE = "update"
|
||||||
|
|
||||||
|
|
||||||
|
_CALLBACK_ERROR_MSG = "Event callback raised an unhandled exception"
|
||||||
|
_SESSION_CREATES = "_ft_creates"
|
||||||
|
_SESSION_DELETES = "_ft_deletes"
|
||||||
|
_SESSION_UPDATES = "_ft_updates"
|
||||||
|
_DEFERRED_STRATEGY_KEY = (("deferred", True), ("instrument", True))
|
||||||
|
_EVENT_HANDLERS: dict[tuple[type, ModelEvent], list[Callable[..., Any]]] = {}
|
||||||
|
_WATCHED_MODELS: set[type] = set()
|
||||||
|
_WATCHED_CACHE: dict[type, bool] = {}
|
||||||
|
_HANDLER_CACHE: dict[tuple[type, ModelEvent], list[Callable[..., Any]]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _invalidate_caches() -> None:
|
||||||
|
"""Clear lookup caches after handler registration."""
|
||||||
|
_WATCHED_CACHE.clear()
|
||||||
|
_HANDLER_CACHE.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def listens_for(
|
||||||
|
model_class: type,
|
||||||
|
event_types: list[ModelEvent] | None = None,
|
||||||
|
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||||
|
"""Register a callback for one or more model lifecycle events.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_class: The SQLAlchemy model class to listen on.
|
||||||
|
event_types: List of :class:`ModelEvent` values to listen for.
|
||||||
|
Defaults to all event types.
|
||||||
|
"""
|
||||||
|
evs = event_types if event_types is not None else list(ModelEvent)
|
||||||
|
|
||||||
|
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
|
for ev in evs:
|
||||||
|
_EVENT_HANDLERS.setdefault((model_class, ev), []).append(fn)
|
||||||
|
_WATCHED_MODELS.add(model_class)
|
||||||
|
_invalidate_caches()
|
||||||
|
return fn
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def _is_watched(obj: Any) -> bool:
|
||||||
|
"""Return True if *obj*'s type (or any ancestor) has registered handlers."""
|
||||||
|
cls = type(obj)
|
||||||
|
try:
|
||||||
|
return _WATCHED_CACHE[cls]
|
||||||
|
except KeyError:
|
||||||
|
result = any(klass in _WATCHED_MODELS for klass in cls.__mro__)
|
||||||
|
_WATCHED_CACHE[cls] = result
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _get_handlers(cls: type, ev: ModelEvent) -> list[Callable[..., Any]]:
|
||||||
|
"""Return registered handlers for *cls* and *ev*, walking the MRO."""
|
||||||
|
key = (cls, ev)
|
||||||
|
try:
|
||||||
|
return _HANDLER_CACHE[key]
|
||||||
|
except KeyError:
|
||||||
|
handlers: list[Callable[..., Any]] = []
|
||||||
|
for klass in cls.__mro__:
|
||||||
|
handlers.extend(_EVENT_HANDLERS.get((klass, ev), []))
|
||||||
|
_HANDLER_CACHE[key] = handlers
|
||||||
|
return handlers
|
||||||
|
|
||||||
|
|
||||||
|
def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
|
||||||
|
"""Read currently-loaded column values into a plain dict."""
|
||||||
|
state = sa_inspect(obj) # InstanceState
|
||||||
|
state_dict = state.dict
|
||||||
|
snapshot: dict[str, Any] = {}
|
||||||
|
for prop in state.mapper.column_attrs:
|
||||||
|
if prop.key in state_dict:
|
||||||
|
snapshot[prop.key] = state_dict[prop.key]
|
||||||
|
elif ( # pragma: no cover
|
||||||
|
not state.expired
|
||||||
|
and prop.strategy_key != _DEFERRED_STRATEGY_KEY
|
||||||
|
and all(
|
||||||
|
col.nullable
|
||||||
|
and col.server_default is None
|
||||||
|
and col.server_onupdate is None
|
||||||
|
for col in prop.columns
|
||||||
|
)
|
||||||
|
):
|
||||||
|
snapshot[prop.key] = None
|
||||||
|
return snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def _get_watched_fields(cls: type) -> tuple[str, ...] | None:
|
||||||
|
"""Return the watched fields for *cls*."""
|
||||||
|
fields = getattr(cls, "__watched_fields__", None)
|
||||||
|
if fields is not None and (
|
||||||
|
not isinstance(fields, tuple) or not all(isinstance(f, str) for f in fields)
|
||||||
|
):
|
||||||
|
raise TypeError(
|
||||||
|
f"{cls.__name__}.__watched_fields__ must be a tuple[str, ...], "
|
||||||
|
f"got {type(fields).__name__}"
|
||||||
|
)
|
||||||
|
return fields
|
||||||
|
|
||||||
|
|
||||||
|
def _upsert_changes(
|
||||||
|
pending: dict[int, tuple[Any, dict[str, dict[str, Any]]]],
|
||||||
|
obj: Any,
|
||||||
|
changes: dict[str, dict[str, Any]],
|
||||||
|
) -> None:
|
||||||
|
"""Insert or merge *changes* into *pending* for *obj*."""
|
||||||
|
key = id(obj)
|
||||||
|
if key in pending:
|
||||||
|
existing = pending[key][1]
|
||||||
|
for field, change in changes.items():
|
||||||
|
if field in existing:
|
||||||
|
existing[field]["new"] = change["new"]
|
||||||
|
else:
|
||||||
|
existing[field] = change
|
||||||
|
else:
|
||||||
|
pending[key] = (obj, changes)
|
||||||
|
|
||||||
|
|
||||||
|
@event.listens_for(AsyncSession.sync_session_class, "after_flush")
|
||||||
|
def _after_flush(session: Any, flush_context: Any) -> None:
|
||||||
|
# New objects: capture reference. Attributes will be refreshed after commit.
|
||||||
|
for obj in session.new:
|
||||||
|
if _is_watched(obj):
|
||||||
|
session.info.setdefault(_SESSION_CREATES, []).append(obj)
|
||||||
|
|
||||||
|
# Deleted objects: snapshot now while attributes are still loaded.
|
||||||
|
for obj in session.deleted:
|
||||||
|
if _is_watched(obj):
|
||||||
|
snapshot = _snapshot_column_attrs(obj)
|
||||||
|
session.info.setdefault(_SESSION_DELETES, []).append((obj, snapshot))
|
||||||
|
|
||||||
|
# Dirty objects: read old/new from SQLAlchemy attribute history.
|
||||||
|
for obj in session.dirty:
|
||||||
|
if not _is_watched(obj):
|
||||||
|
continue
|
||||||
|
|
||||||
|
watched = _get_watched_fields(type(obj))
|
||||||
|
changes: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
|
inst_attrs = sa_inspect(obj).attrs
|
||||||
|
attrs = (
|
||||||
|
((field, inst_attrs[field]) for field in watched)
|
||||||
|
if watched is not None
|
||||||
|
else ((s.key, s) for s in inst_attrs)
|
||||||
|
)
|
||||||
|
for field, attr_state in attrs:
|
||||||
|
history = attr_state.history
|
||||||
|
if history.has_changes() and history.deleted:
|
||||||
|
changes[field] = {
|
||||||
|
"old": history.deleted[0],
|
||||||
|
"new": history.added[0] if history.added else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
if changes:
|
||||||
|
_upsert_changes(
|
||||||
|
session.info.setdefault(_SESSION_UPDATES, {}),
|
||||||
|
obj,
|
||||||
|
changes,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@event.listens_for(AsyncSession.sync_session_class, "after_rollback")
|
||||||
|
def _after_rollback(session: Any) -> None:
|
||||||
|
if session.in_transaction():
|
||||||
|
return
|
||||||
|
session.info.pop(_SESSION_CREATES, None)
|
||||||
|
session.info.pop(_SESSION_DELETES, None)
|
||||||
|
session.info.pop(_SESSION_UPDATES, None)
|
||||||
|
|
||||||
|
|
||||||
|
async def _invoke_callback(
|
||||||
|
fn: Callable[..., Any],
|
||||||
|
obj: Any,
|
||||||
|
event_type: ModelEvent,
|
||||||
|
changes: dict[str, dict[str, Any]] | None,
|
||||||
|
) -> None:
|
||||||
|
"""Call *fn* and await the result if it is awaitable."""
|
||||||
|
result = fn(obj, event_type, changes)
|
||||||
|
if inspect.isawaitable(result):
|
||||||
|
await result
|
||||||
|
|
||||||
|
|
||||||
|
class EventSession(AsyncSession):
|
||||||
|
"""AsyncSession subclass that dispatches lifecycle callbacks after commit."""
|
||||||
|
|
||||||
|
async def commit(self) -> None: # noqa: C901
|
||||||
|
await super().commit()
|
||||||
|
|
||||||
|
creates: list[Any] = self.info.pop(_SESSION_CREATES, [])
|
||||||
|
deletes: list[tuple[Any, dict[str, Any]]] = self.info.pop(_SESSION_DELETES, [])
|
||||||
|
field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = self.info.pop(
|
||||||
|
_SESSION_UPDATES, {}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not creates and not deletes and not field_changes:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Suppress transient objects (created + deleted in same transaction).
|
||||||
|
if creates and deletes:
|
||||||
|
created_ids = {id(o) for o in creates}
|
||||||
|
deleted_ids = {id(o) for o, _ in deletes}
|
||||||
|
transient_ids = created_ids & deleted_ids
|
||||||
|
if transient_ids:
|
||||||
|
creates = [o for o in creates if id(o) not in transient_ids]
|
||||||
|
deletes = [(o, s) for o, s in deletes if id(o) not in transient_ids]
|
||||||
|
field_changes = {
|
||||||
|
k: v for k, v in field_changes.items() if k not in transient_ids
|
||||||
|
}
|
||||||
|
|
||||||
|
# Suppress updates for deleted objects (row is gone, refresh would fail).
|
||||||
|
if deletes and field_changes:
|
||||||
|
deleted_ids = {id(o) for o, _ in deletes}
|
||||||
|
field_changes = {
|
||||||
|
k: v for k, v in field_changes.items() if k not in deleted_ids
|
||||||
|
}
|
||||||
|
|
||||||
|
# Suppress updates for newly created objects (CREATE-only semantics).
|
||||||
|
if creates and field_changes:
|
||||||
|
create_ids = {id(o) for o in creates}
|
||||||
|
field_changes = {
|
||||||
|
k: v for k, v in field_changes.items() if k not in create_ids
|
||||||
|
}
|
||||||
|
|
||||||
|
# Dispatch CREATE callbacks.
|
||||||
|
for obj in creates:
|
||||||
|
try:
|
||||||
|
state = sa_inspect(obj, raiseerr=False)
|
||||||
|
if (
|
||||||
|
state is None or state.detached or state.transient
|
||||||
|
): # pragma: no cover
|
||||||
|
continue
|
||||||
|
await self.refresh(obj)
|
||||||
|
for handler in _get_handlers(type(obj), ModelEvent.CREATE):
|
||||||
|
await _invoke_callback(handler, obj, ModelEvent.CREATE, None)
|
||||||
|
except Exception as exc:
|
||||||
|
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||||
|
|
||||||
|
# Dispatch DELETE callbacks (restore snapshot; row is gone).
|
||||||
|
for obj, snapshot in deletes:
|
||||||
|
try:
|
||||||
|
for key, value in snapshot.items():
|
||||||
|
_sa_set_committed_value(obj, key, value)
|
||||||
|
for handler in _get_handlers(type(obj), ModelEvent.DELETE):
|
||||||
|
await _invoke_callback(handler, obj, ModelEvent.DELETE, None)
|
||||||
|
except Exception as exc:
|
||||||
|
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||||
|
|
||||||
|
# Dispatch UPDATE callbacks.
|
||||||
|
for obj, changes in field_changes.values():
|
||||||
|
try:
|
||||||
|
state = sa_inspect(obj, raiseerr=False)
|
||||||
|
if (
|
||||||
|
state is None or state.detached or state.transient
|
||||||
|
): # pragma: no cover
|
||||||
|
continue
|
||||||
|
await self.refresh(obj)
|
||||||
|
for handler in _get_handlers(type(obj), ModelEvent.UPDATE):
|
||||||
|
await _invoke_callback(handler, obj, ModelEvent.UPDATE, changes)
|
||||||
|
except Exception as exc:
|
||||||
|
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||||
|
|
||||||
|
async def rollback(self) -> None:
|
||||||
|
await super().rollback()
|
||||||
|
self.info.pop(_SESSION_CREATES, None)
|
||||||
|
self.info.pop(_SESSION_DELETES, None)
|
||||||
|
self.info.pop(_SESSION_UPDATES, None)
|
||||||
30
src/fastapi_toolsets/pytest/__init__.py
Normal file
30
src/fastapi_toolsets/pytest/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""Pytest helpers for FastAPI testing: sessions, clients, and fixtures."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .plugin import register_fixtures
|
||||||
|
except ImportError:
|
||||||
|
from .._imports import require_extra
|
||||||
|
|
||||||
|
require_extra(package="pytest", extra="pytest")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .utils import (
|
||||||
|
cleanup_tables,
|
||||||
|
create_async_client,
|
||||||
|
create_db_session,
|
||||||
|
create_worker_database,
|
||||||
|
worker_database_url,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
from .._imports import require_extra
|
||||||
|
|
||||||
|
require_extra(package="httpx", extra="pytest")
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"cleanup_tables",
|
||||||
|
"create_async_client",
|
||||||
|
"create_db_session",
|
||||||
|
"create_worker_database",
|
||||||
|
"register_fixtures",
|
||||||
|
"worker_database_url",
|
||||||
|
]
|
||||||
@@ -1,55 +1,4 @@
|
|||||||
"""Pytest plugin for using FixtureRegistry fixtures in tests.
|
"""Pytest plugin for using FixtureRegistry fixtures in tests."""
|
||||||
|
|
||||||
This module provides utilities to automatically generate pytest fixtures
|
|
||||||
from your FixtureRegistry, with proper dependency resolution.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# conftest.py
|
|
||||||
import pytest
|
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
|
||||||
|
|
||||||
from app.fixtures import fixtures # Your FixtureRegistry
|
|
||||||
from app.models import Base
|
|
||||||
from fastapi_toolsets.pytest_plugin import register_fixtures
|
|
||||||
|
|
||||||
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/test_db"
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def engine():
|
|
||||||
engine = create_async_engine(DATABASE_URL)
|
|
||||||
yield engine
|
|
||||||
await engine.dispose()
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def db_session(engine):
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
|
||||||
|
|
||||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
|
||||||
session = session_factory()
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield session
|
|
||||||
finally:
|
|
||||||
await session.close()
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.drop_all)
|
|
||||||
|
|
||||||
# Automatically generate pytest fixtures from registry
|
|
||||||
# Creates: fixture_roles, fixture_users, fixture_posts, etc.
|
|
||||||
register_fixtures(fixtures, globals())
|
|
||||||
|
|
||||||
Usage in tests:
|
|
||||||
# test_users.py
|
|
||||||
async def test_user_count(db_session, fixture_users):
|
|
||||||
# fixture_users automatically loads fixture_roles first (if dependency)
|
|
||||||
# and returns the list of User models
|
|
||||||
assert len(fixture_users) > 0
|
|
||||||
|
|
||||||
async def test_user_role(db_session, fixture_users):
|
|
||||||
user = fixture_users[0]
|
|
||||||
assert user.role_id is not None
|
|
||||||
"""
|
|
||||||
|
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -59,7 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
from ..db import get_transaction
|
from ..db import get_transaction
|
||||||
from .fixtures import FixtureRegistry, LoadStrategy
|
from ..fixtures import FixtureRegistry, LoadStrategy
|
||||||
|
|
||||||
|
|
||||||
def register_fixtures(
|
def register_fixtures(
|
||||||
@@ -86,6 +35,7 @@ def register_fixtures(
|
|||||||
List of created fixture names
|
List of created fixture names
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
# conftest.py
|
# conftest.py
|
||||||
from app.fixtures import fixtures
|
from app.fixtures import fixtures
|
||||||
from fastapi_toolsets.pytest_plugin import register_fixtures
|
from fastapi_toolsets.pytest_plugin import register_fixtures
|
||||||
@@ -96,6 +46,7 @@ def register_fixtures(
|
|||||||
# - fixture_roles
|
# - fixture_roles
|
||||||
# - fixture_users (depends on fixture_roles if users depends on roles)
|
# - fixture_users (depends on fixture_roles if users depends on roles)
|
||||||
# - fixture_posts (depends on fixture_users if posts depends on users)
|
# - fixture_posts (depends on fixture_users if posts depends on users)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
created_fixtures: list[str] = []
|
created_fixtures: list[str] = []
|
||||||
|
|
||||||
260
src/fastapi_toolsets/pytest/utils.py
Normal file
260
src/fastapi_toolsets/pytest/utils.py
Normal file
@@ -0,0 +1,260 @@
|
|||||||
|
"""Pytest helper utilities for FastAPI testing."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from collections.abc import AsyncGenerator, Callable
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.engine import make_url
|
||||||
|
from sqlalchemy.ext.asyncio import (
|
||||||
|
AsyncSession,
|
||||||
|
async_sessionmaker,
|
||||||
|
create_async_engine,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from ..db import cleanup_tables, create_database
|
||||||
|
from ..models.watched import EventSession
|
||||||
|
|
||||||
|
|
||||||
|
def _get_xdist_worker(default_test_db: str) -> str:
|
||||||
|
"""Return the pytest-xdist worker name, or *default_test_db* when not running under xdist.
|
||||||
|
|
||||||
|
Reads the ``PYTEST_XDIST_WORKER`` environment variable that xdist sets
|
||||||
|
automatically in each worker process (e.g. ``"gw0"``, ``"gw1"``).
|
||||||
|
When xdist is not installed or not active, the variable is absent and
|
||||||
|
*default_test_db* is returned instead.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
default_test_db: Fallback value returned when ``PYTEST_XDIST_WORKER``
|
||||||
|
is not set.
|
||||||
|
"""
|
||||||
|
return os.environ.get("PYTEST_XDIST_WORKER", default_test_db)
|
||||||
|
|
||||||
|
|
||||||
|
def worker_database_url(database_url: str, default_test_db: str) -> str:
|
||||||
|
"""Derive a per-worker database URL for pytest-xdist parallel runs.
|
||||||
|
|
||||||
|
Appends ``_{worker_name}`` to the database name so each xdist worker
|
||||||
|
operates on its own database. When not running under xdist,
|
||||||
|
``_{default_test_db}`` is appended instead.
|
||||||
|
|
||||||
|
The worker name is read from the ``PYTEST_XDIST_WORKER`` environment
|
||||||
|
variable (set automatically by xdist in each worker process).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database_url: Original database connection URL.
|
||||||
|
default_test_db: Suffix appended to the database name when
|
||||||
|
``PYTEST_XDIST_WORKER`` is not set.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A database URL with a worker- or default-specific database name.
|
||||||
|
"""
|
||||||
|
worker = _get_xdist_worker(default_test_db=default_test_db)
|
||||||
|
|
||||||
|
url = make_url(database_url)
|
||||||
|
url = url.set(database=f"{url.database}_{worker}")
|
||||||
|
return url.render_as_string(hide_password=False)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def create_worker_database(
|
||||||
|
database_url: str,
|
||||||
|
default_test_db: str = "test_db",
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""Create and drop a per-worker database for pytest-xdist isolation.
|
||||||
|
|
||||||
|
Derives a worker-specific database URL using :func:`worker_database_url`,
|
||||||
|
then delegates to :func:`~fastapi_toolsets.db.create_database` to create
|
||||||
|
and drop it. Intended for use as a **session-scoped** fixture.
|
||||||
|
|
||||||
|
When running under xdist the database name is suffixed with the worker
|
||||||
|
name (e.g. ``_gw0``). Otherwise it is suffixed with *default_test_db*.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database_url: Original database connection URL (used as the server
|
||||||
|
connection and as the base for the worker database name).
|
||||||
|
default_test_db: Suffix appended to the database name when
|
||||||
|
``PYTEST_XDIST_WORKER`` is not set. Defaults to ``"test_db"``.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The worker-specific database URL.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import create_worker_database, create_db_session
|
||||||
|
|
||||||
|
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost/test_db"
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
async def worker_db_url():
|
||||||
|
async with create_worker_database(DATABASE_URL) as url:
|
||||||
|
yield url
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session(worker_db_url):
|
||||||
|
async with create_db_session(
|
||||||
|
worker_db_url, Base, cleanup=True
|
||||||
|
) as session:
|
||||||
|
yield session
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
worker_url = worker_database_url(
|
||||||
|
database_url=database_url, default_test_db=default_test_db
|
||||||
|
)
|
||||||
|
worker_db_name = make_url(worker_url).database
|
||||||
|
assert worker_db_name is not None
|
||||||
|
|
||||||
|
engine = create_async_engine(database_url, isolation_level="AUTOCOMMIT")
|
||||||
|
try:
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
await conn.execute(text(f"DROP DATABASE IF EXISTS {worker_db_name}"))
|
||||||
|
await create_database(db_name=worker_db_name, server_url=database_url)
|
||||||
|
|
||||||
|
yield worker_url
|
||||||
|
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
await conn.execute(text(f"DROP DATABASE IF EXISTS {worker_db_name}"))
|
||||||
|
finally:
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def create_async_client(
|
||||||
|
app: Any,
|
||||||
|
base_url: str = "http://test",
|
||||||
|
dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None,
|
||||||
|
) -> AsyncGenerator[AsyncClient, None]:
|
||||||
|
"""Create an async httpx client for testing FastAPI applications.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI application instance.
|
||||||
|
base_url: Base URL for requests. Defaults to "http://test".
|
||||||
|
dependency_overrides: Optional mapping of original dependencies to
|
||||||
|
their test replacements. Applied via ``app.dependency_overrides``
|
||||||
|
before yielding and cleaned up after.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
An AsyncClient configured for the app.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.pytest import create_async_client
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client():
|
||||||
|
async with create_async_client(app) as c:
|
||||||
|
yield c
|
||||||
|
|
||||||
|
async def test_endpoint(client: AsyncClient):
|
||||||
|
response = await client.get("/health")
|
||||||
|
assert response.status_code == 200
|
||||||
|
```
|
||||||
|
|
||||||
|
Example with dependency overrides:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import create_async_client, create_db_session
|
||||||
|
from app.db import get_db
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session():
|
||||||
|
async with create_db_session(DATABASE_URL, Base, cleanup=True) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client(db_session):
|
||||||
|
async def override():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
async with create_async_client(
|
||||||
|
app, dependency_overrides={get_db: override}
|
||||||
|
) as c:
|
||||||
|
yield c
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if dependency_overrides:
|
||||||
|
app.dependency_overrides.update(dependency_overrides)
|
||||||
|
|
||||||
|
transport = ASGITransport(app=app)
|
||||||
|
try:
|
||||||
|
async with AsyncClient(transport=transport, base_url=base_url) as client:
|
||||||
|
yield client
|
||||||
|
finally:
|
||||||
|
if dependency_overrides:
|
||||||
|
for key in dependency_overrides:
|
||||||
|
app.dependency_overrides.pop(key, None)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def create_db_session(
|
||||||
|
database_url: str,
|
||||||
|
base: type[DeclarativeBase],
|
||||||
|
*,
|
||||||
|
echo: bool = False,
|
||||||
|
expire_on_commit: bool = False,
|
||||||
|
drop_tables: bool = True,
|
||||||
|
cleanup: bool = False,
|
||||||
|
) -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""Create a database session for testing.
|
||||||
|
|
||||||
|
Creates tables before yielding the session and optionally drops them after.
|
||||||
|
Each call creates a fresh engine and session for test isolation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database_url: Database connection URL (e.g., "postgresql+asyncpg://...").
|
||||||
|
base: SQLAlchemy DeclarativeBase class containing model metadata.
|
||||||
|
echo: Enable SQLAlchemy query logging. Defaults to False.
|
||||||
|
expire_on_commit: Expire objects after commit. Defaults to False.
|
||||||
|
drop_tables: Drop tables after test. Defaults to True.
|
||||||
|
cleanup: Truncate all tables after test using
|
||||||
|
:func:`cleanup_tables`. Defaults to False.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
An AsyncSession ready for database operations.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import create_db_session
|
||||||
|
from app.models import Base
|
||||||
|
|
||||||
|
DATABASE_URL = "postgresql+asyncpg://user:pass@localhost/test_db"
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session():
|
||||||
|
async with create_db_session(
|
||||||
|
DATABASE_URL, Base, cleanup=True
|
||||||
|
) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
async def test_create_user(db_session: AsyncSession):
|
||||||
|
user = User(name="test")
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
engine = create_async_engine(database_url, echo=echo)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create tables
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(base.metadata.create_all)
|
||||||
|
|
||||||
|
session_maker = async_sessionmaker(
|
||||||
|
engine, expire_on_commit=expire_on_commit, class_=EventSession
|
||||||
|
)
|
||||||
|
async with session_maker() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
if cleanup:
|
||||||
|
await cleanup_tables(session=session, base=base)
|
||||||
|
|
||||||
|
if drop_tables:
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(base.metadata.drop_all)
|
||||||
|
finally:
|
||||||
|
await engine.dispose()
|
||||||
@@ -1,21 +1,27 @@
|
|||||||
"""Base Pydantic schemas for API responses."""
|
"""Base Pydantic schemas for API responses."""
|
||||||
|
|
||||||
|
import math
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import ClassVar, Generic, TypeVar
|
from typing import Annotated, Any, ClassVar, Generic, Literal, TypeVar, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict, Field, computed_field
|
||||||
|
|
||||||
|
from .types import DataT
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ApiError",
|
"ApiError",
|
||||||
|
"CursorPagination",
|
||||||
|
"CursorPaginatedResponse",
|
||||||
"ErrorResponse",
|
"ErrorResponse",
|
||||||
"Pagination",
|
"OffsetPagination",
|
||||||
|
"OffsetPaginatedResponse",
|
||||||
"PaginatedResponse",
|
"PaginatedResponse",
|
||||||
|
"PaginationType",
|
||||||
|
"PydanticBase",
|
||||||
"Response",
|
"Response",
|
||||||
"ResponseStatus",
|
"ResponseStatus",
|
||||||
]
|
]
|
||||||
|
|
||||||
DataT = TypeVar("DataT")
|
|
||||||
|
|
||||||
|
|
||||||
class PydanticBase(BaseModel):
|
class PydanticBase(BaseModel):
|
||||||
"""Base class for all Pydantic models with common configuration."""
|
"""Base class for all Pydantic models with common configuration."""
|
||||||
@@ -49,6 +55,7 @@ class ApiError(PydanticBase):
|
|||||||
msg: str
|
msg: str
|
||||||
desc: str
|
desc: str
|
||||||
err_code: str
|
err_code: str
|
||||||
|
data: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
class BaseResponse(PydanticBase):
|
class BaseResponse(PydanticBase):
|
||||||
@@ -69,7 +76,9 @@ class Response(BaseResponse, Generic[DataT]):
|
|||||||
"""Generic API response with data payload.
|
"""Generic API response with data payload.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
Response[UserRead](data=user, message="User retrieved")
|
Response[UserRead](data=user, message="User retrieved")
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data: DataT | None = None
|
data: DataT | None = None
|
||||||
@@ -83,34 +92,115 @@ class ErrorResponse(BaseResponse):
|
|||||||
|
|
||||||
status: ResponseStatus = ResponseStatus.FAIL
|
status: ResponseStatus = ResponseStatus.FAIL
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
data: None = None
|
data: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
class Pagination(PydanticBase):
|
class OffsetPagination(PydanticBase):
|
||||||
"""Pagination metadata for list responses.
|
"""Pagination metadata for offset-based list responses.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
total_count: Total number of items across all pages
|
total_count: Total number of items across all pages.
|
||||||
|
``None`` when ``include_total=False``.
|
||||||
items_per_page: Number of items per page
|
items_per_page: Number of items per page
|
||||||
page: Current page number (1-indexed)
|
page: Current page number (1-indexed)
|
||||||
has_more: Whether there are more pages
|
has_more: Whether there are more pages
|
||||||
|
pages: Total number of pages
|
||||||
"""
|
"""
|
||||||
|
|
||||||
total_count: int
|
total_count: int | None
|
||||||
items_per_page: int
|
items_per_page: int
|
||||||
page: int
|
page: int
|
||||||
has_more: bool
|
has_more: bool
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def pages(self) -> int | None:
|
||||||
|
"""Total number of pages, or ``None`` when ``total_count`` is unknown."""
|
||||||
|
if self.total_count is None:
|
||||||
|
return None
|
||||||
|
if self.items_per_page == 0:
|
||||||
|
return 0
|
||||||
|
return math.ceil(self.total_count / self.items_per_page)
|
||||||
|
|
||||||
|
|
||||||
|
class CursorPagination(PydanticBase):
|
||||||
|
"""Pagination metadata for cursor-based list responses.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
next_cursor: Encoded cursor for the next page, or None on the last page.
|
||||||
|
prev_cursor: Encoded cursor for the previous page, or None on the first page.
|
||||||
|
items_per_page: Number of items requested per page.
|
||||||
|
has_more: Whether there is at least one more page after this one.
|
||||||
|
"""
|
||||||
|
|
||||||
|
next_cursor: str | None
|
||||||
|
prev_cursor: str | None = None
|
||||||
|
items_per_page: int
|
||||||
|
has_more: bool
|
||||||
|
|
||||||
|
|
||||||
|
class PaginationType(str, Enum):
|
||||||
|
"""Pagination strategy selector for :meth:`.AsyncCrud.paginate`."""
|
||||||
|
|
||||||
|
OFFSET = "offset"
|
||||||
|
CURSOR = "cursor"
|
||||||
|
|
||||||
|
|
||||||
class PaginatedResponse(BaseResponse, Generic[DataT]):
|
class PaginatedResponse(BaseResponse, Generic[DataT]):
|
||||||
"""Paginated API response for list endpoints.
|
"""Paginated API response for list endpoints.
|
||||||
|
|
||||||
Example:
|
Base class and return type for endpoints that support both pagination
|
||||||
PaginatedResponse[UserRead](
|
strategies. Use :class:`OffsetPaginatedResponse` or
|
||||||
data=users,
|
:class:`CursorPaginatedResponse` when the strategy is fixed.
|
||||||
pagination=Pagination(total_count=100, items_per_page=10, page=1, has_more=True)
|
|
||||||
)
|
When used as ``PaginatedResponse[T]`` in a return annotation, subscripting
|
||||||
|
returns ``Annotated[Union[CursorPaginatedResponse[T], OffsetPaginatedResponse[T]], Field(discriminator="pagination_type")]``
|
||||||
|
so FastAPI emits a proper ``oneOf`` + discriminator in the OpenAPI schema.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data: list[DataT]
|
data: list[DataT]
|
||||||
pagination: Pagination
|
pagination: OffsetPagination | CursorPagination
|
||||||
|
pagination_type: PaginationType | None = None
|
||||||
|
filter_attributes: dict[str, list[Any]] | None = None
|
||||||
|
search_columns: list[str] | None = None
|
||||||
|
sort_columns: list[str] | None = None
|
||||||
|
|
||||||
|
_discriminated_union_cache: ClassVar[dict[Any, Any]] = {}
|
||||||
|
|
||||||
|
def __class_getitem__( # ty:ignore[invalid-method-override]
|
||||||
|
cls, item: type[Any] | tuple[type[Any], ...]
|
||||||
|
) -> type[Any]:
|
||||||
|
if cls is PaginatedResponse and not isinstance(item, TypeVar):
|
||||||
|
cached = cls._discriminated_union_cache.get(item)
|
||||||
|
if cached is None:
|
||||||
|
cached = Annotated[
|
||||||
|
Union[CursorPaginatedResponse[item], OffsetPaginatedResponse[item]], # ty:ignore[invalid-type-form]
|
||||||
|
Field(discriminator="pagination_type"),
|
||||||
|
]
|
||||||
|
cls._discriminated_union_cache[item] = cached
|
||||||
|
return cached # ty:ignore[invalid-return-type]
|
||||||
|
return super().__class_getitem__(item)
|
||||||
|
|
||||||
|
|
||||||
|
class OffsetPaginatedResponse(PaginatedResponse[DataT]):
|
||||||
|
"""Paginated response with typed offset-based pagination metadata.
|
||||||
|
|
||||||
|
The ``pagination_type`` field is always ``"offset"`` and acts as a
|
||||||
|
discriminator, allowing frontend clients to narrow the union type returned
|
||||||
|
by a unified ``paginate()`` endpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pagination: OffsetPagination
|
||||||
|
pagination_type: Literal[PaginationType.OFFSET] = PaginationType.OFFSET
|
||||||
|
|
||||||
|
|
||||||
|
class CursorPaginatedResponse(PaginatedResponse[DataT]):
|
||||||
|
"""Paginated response with typed cursor-based pagination metadata.
|
||||||
|
|
||||||
|
The ``pagination_type`` field is always ``"cursor"`` and acts as a
|
||||||
|
discriminator, allowing frontend clients to narrow the union type returned
|
||||||
|
by a unified ``paginate()`` endpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pagination: CursorPagination
|
||||||
|
pagination_type: Literal[PaginationType.CURSOR] = PaginationType.CURSOR
|
||||||
|
|||||||
24
src/fastapi_toolsets/security/__init__.py
Normal file
24
src/fastapi_toolsets/security/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
"""Authentication helpers for FastAPI using Security()."""
|
||||||
|
|
||||||
|
from .abc import AuthSource
|
||||||
|
from .oauth import (
|
||||||
|
oauth_build_authorization_redirect,
|
||||||
|
oauth_decode_state,
|
||||||
|
oauth_encode_state,
|
||||||
|
oauth_fetch_userinfo,
|
||||||
|
oauth_resolve_provider_urls,
|
||||||
|
)
|
||||||
|
from .sources import APIKeyHeaderAuth, BearerTokenAuth, CookieAuth, MultiAuth
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"APIKeyHeaderAuth",
|
||||||
|
"AuthSource",
|
||||||
|
"BearerTokenAuth",
|
||||||
|
"CookieAuth",
|
||||||
|
"MultiAuth",
|
||||||
|
"oauth_build_authorization_redirect",
|
||||||
|
"oauth_decode_state",
|
||||||
|
"oauth_encode_state",
|
||||||
|
"oauth_fetch_userinfo",
|
||||||
|
"oauth_resolve_provider_urls",
|
||||||
|
]
|
||||||
53
src/fastapi_toolsets/security/abc.py
Normal file
53
src/fastapi_toolsets/security/abc.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
"""Abstract base class for authentication sources."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from fastapi.security import SecurityScopes
|
||||||
|
|
||||||
|
from fastapi_toolsets.exceptions import UnauthorizedError
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_async(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
|
"""Wrap *fn* so it can always be awaited, caching the coroutine check at init time."""
|
||||||
|
if inspect.iscoroutinefunction(fn):
|
||||||
|
return fn
|
||||||
|
|
||||||
|
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class AuthSource(ABC):
|
||||||
|
"""Abstract base class for authentication sources."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Set up the default FastAPI dependency signature."""
|
||||||
|
source = self
|
||||||
|
|
||||||
|
async def _call(
|
||||||
|
request: Request,
|
||||||
|
security_scopes: SecurityScopes, # noqa: ARG001
|
||||||
|
) -> Any:
|
||||||
|
credential = await source.extract(request)
|
||||||
|
if credential is None:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
return await source.authenticate(credential)
|
||||||
|
|
||||||
|
self._call_fn: Callable[..., Any] = _call
|
||||||
|
self.__signature__ = inspect.signature(_call)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def extract(self, request: Request) -> str | None:
|
||||||
|
"""Extract the raw credential from the request without validating."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def authenticate(self, credential: str) -> Any:
|
||||||
|
"""Validate a credential and return the authenticated identity."""
|
||||||
|
|
||||||
|
async def __call__(self, **kwargs: Any) -> Any:
|
||||||
|
"""FastAPI dependency dispatch."""
|
||||||
|
return await self._call_fn(**kwargs)
|
||||||
140
src/fastapi_toolsets/security/oauth.py
Normal file
140
src/fastapi_toolsets/security/oauth.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
"""OAuth 2.0 / OIDC helper utilities."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
from typing import Any
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
|
_discovery_cache: dict[str, dict] = {}
|
||||||
|
|
||||||
|
|
||||||
|
async def oauth_resolve_provider_urls(
|
||||||
|
discovery_url: str,
|
||||||
|
) -> tuple[str, str, str | None]:
|
||||||
|
"""Fetch the OIDC discovery document and return endpoint URLs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
discovery_url: URL of the provider's ``/.well-known/openid-configuration``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A ``(authorization_url, token_url, userinfo_url)`` tuple.
|
||||||
|
*userinfo_url* is ``None`` when the provider does not advertise one.
|
||||||
|
"""
|
||||||
|
if discovery_url not in _discovery_cache:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
resp = await client.get(discovery_url)
|
||||||
|
resp.raise_for_status()
|
||||||
|
_discovery_cache[discovery_url] = resp.json()
|
||||||
|
cfg = _discovery_cache[discovery_url]
|
||||||
|
return (
|
||||||
|
cfg["authorization_endpoint"],
|
||||||
|
cfg["token_endpoint"],
|
||||||
|
cfg.get("userinfo_endpoint"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def oauth_fetch_userinfo(
|
||||||
|
*,
|
||||||
|
token_url: str,
|
||||||
|
userinfo_url: str,
|
||||||
|
code: str,
|
||||||
|
client_id: str,
|
||||||
|
client_secret: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Exchange an authorization code for tokens and return the userinfo payload.
|
||||||
|
|
||||||
|
Performs the two-step OAuth 2.0 / OIDC token exchange:
|
||||||
|
|
||||||
|
1. POSTs the authorization *code* to *token_url* to obtain an access token.
|
||||||
|
2. GETs *userinfo_url* using that access token as a Bearer credential.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_url: Provider's token endpoint.
|
||||||
|
userinfo_url: Provider's userinfo endpoint.
|
||||||
|
code: Authorization code received from the provider's callback.
|
||||||
|
client_id: OAuth application client ID.
|
||||||
|
client_secret: OAuth application client secret.
|
||||||
|
redirect_uri: Redirect URI that was used in the authorization request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The JSON payload returned by the userinfo endpoint as a plain ``dict``.
|
||||||
|
"""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
token_resp = await client.post(
|
||||||
|
token_url,
|
||||||
|
data={
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"code": code,
|
||||||
|
"client_id": client_id,
|
||||||
|
"client_secret": client_secret,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
},
|
||||||
|
headers={"Accept": "application/json"},
|
||||||
|
)
|
||||||
|
token_resp.raise_for_status()
|
||||||
|
access_token = token_resp.json()["access_token"]
|
||||||
|
|
||||||
|
userinfo_resp = await client.get(
|
||||||
|
userinfo_url,
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
userinfo_resp.raise_for_status()
|
||||||
|
return userinfo_resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
def oauth_build_authorization_redirect(
|
||||||
|
authorization_url: str,
|
||||||
|
*,
|
||||||
|
client_id: str,
|
||||||
|
scopes: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
destination: str,
|
||||||
|
) -> RedirectResponse:
|
||||||
|
"""Return an OAuth 2.0 authorization ``RedirectResponse``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
authorization_url: Provider's authorization endpoint.
|
||||||
|
client_id: OAuth application client ID.
|
||||||
|
scopes: Space-separated list of requested scopes.
|
||||||
|
redirect_uri: URI the provider should redirect back to after authorization.
|
||||||
|
destination: URL the user should be sent to after the full OAuth flow
|
||||||
|
completes (encoded as ``state``).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A :class:`~fastapi.responses.RedirectResponse` to the provider's
|
||||||
|
authorization page.
|
||||||
|
"""
|
||||||
|
params = urlencode(
|
||||||
|
{
|
||||||
|
"client_id": client_id,
|
||||||
|
"response_type": "code",
|
||||||
|
"scope": scopes,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"state": oauth_encode_state(destination),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return RedirectResponse(f"{authorization_url}?{params}")
|
||||||
|
|
||||||
|
|
||||||
|
def oauth_encode_state(url: str) -> str:
|
||||||
|
"""Base64url-encode a URL to embed as an OAuth ``state`` parameter."""
|
||||||
|
return base64.urlsafe_b64encode(url.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def oauth_decode_state(state: str | None, *, fallback: str) -> str:
|
||||||
|
"""Decode a base64url OAuth ``state`` parameter.
|
||||||
|
|
||||||
|
Handles missing padding (some providers strip ``=``).
|
||||||
|
Returns *fallback* if *state* is absent, the literal string ``"null"``,
|
||||||
|
or cannot be decoded.
|
||||||
|
"""
|
||||||
|
if not state or state == "null":
|
||||||
|
return fallback
|
||||||
|
try:
|
||||||
|
padded = state + "=" * (4 - len(state) % 4)
|
||||||
|
return base64.urlsafe_b64decode(padded).decode()
|
||||||
|
except Exception:
|
||||||
|
return fallback
|
||||||
8
src/fastapi_toolsets/security/sources/__init__.py
Normal file
8
src/fastapi_toolsets/security/sources/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
"""Built-in authentication source implementations."""
|
||||||
|
|
||||||
|
from .header import APIKeyHeaderAuth
|
||||||
|
from .bearer import BearerTokenAuth
|
||||||
|
from .cookie import CookieAuth
|
||||||
|
from .multi import MultiAuth
|
||||||
|
|
||||||
|
__all__ = ["APIKeyHeaderAuth", "BearerTokenAuth", "CookieAuth", "MultiAuth"]
|
||||||
120
src/fastapi_toolsets/security/sources/bearer.py
Normal file
120
src/fastapi_toolsets/security/sources/bearer.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
"""Bearer token authentication source."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import secrets
|
||||||
|
from typing import Annotated, Any, Callable
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer, SecurityScopes
|
||||||
|
|
||||||
|
from fastapi_toolsets.exceptions import UnauthorizedError
|
||||||
|
|
||||||
|
from ..abc import AuthSource, _ensure_async
|
||||||
|
|
||||||
|
|
||||||
|
class BearerTokenAuth(AuthSource):
|
||||||
|
"""Bearer token authentication source.
|
||||||
|
|
||||||
|
Wraps :class:`fastapi.security.HTTPBearer` for OpenAPI documentation.
|
||||||
|
The validator is called as ``await validator(credential, **kwargs)``
|
||||||
|
where ``kwargs`` are the extra keyword arguments provided at instantiation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
validator: Sync or async callable that receives the credential and any
|
||||||
|
extra keyword arguments, and returns the authenticated identity
|
||||||
|
(e.g. a ``User`` model). Should raise
|
||||||
|
:class:`~fastapi_toolsets.exceptions.UnauthorizedError` on failure.
|
||||||
|
prefix: Optional token prefix (e.g. ``"user_"``). If set, only tokens
|
||||||
|
whose value starts with this prefix are matched. The prefix is
|
||||||
|
**kept** in the value passed to the validator — store and compare
|
||||||
|
tokens with their prefix included. Use :meth:`generate_token` to
|
||||||
|
create correctly-prefixed tokens. This enables multiple
|
||||||
|
``BearerTokenAuth`` instances in the same app (e.g. ``"user_"``
|
||||||
|
for user tokens, ``"org_"`` for org tokens).
|
||||||
|
**kwargs: Extra keyword arguments forwarded to the validator on every
|
||||||
|
call (e.g. ``role=Role.ADMIN``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
validator: Callable[..., Any],
|
||||||
|
*,
|
||||||
|
prefix: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
self._validator = _ensure_async(validator)
|
||||||
|
self._prefix = prefix
|
||||||
|
self._kwargs = kwargs
|
||||||
|
self._scheme = HTTPBearer(auto_error=False)
|
||||||
|
|
||||||
|
async def _call(
|
||||||
|
security_scopes: SecurityScopes, # noqa: ARG001
|
||||||
|
credentials: Annotated[
|
||||||
|
HTTPAuthorizationCredentials | None, Depends(self._scheme)
|
||||||
|
] = None,
|
||||||
|
) -> Any:
|
||||||
|
if credentials is None:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
return await self._validate(credentials.credentials)
|
||||||
|
|
||||||
|
self._call_fn = _call
|
||||||
|
self.__signature__ = inspect.signature(_call)
|
||||||
|
|
||||||
|
async def _validate(self, token: str) -> Any:
|
||||||
|
"""Check prefix and call the validator."""
|
||||||
|
if self._prefix is not None and not token.startswith(self._prefix):
|
||||||
|
raise UnauthorizedError()
|
||||||
|
return await self._validator(token, **self._kwargs)
|
||||||
|
|
||||||
|
async def extract(self, request: Any) -> str | None:
|
||||||
|
"""Extract the raw credential from the request without validating.
|
||||||
|
|
||||||
|
Returns ``None`` if no ``Authorization: Bearer`` header is present,
|
||||||
|
the token is empty, or the token does not match the configured prefix.
|
||||||
|
The prefix is included in the returned value.
|
||||||
|
"""
|
||||||
|
auth = request.headers.get("Authorization", "")
|
||||||
|
if not auth.startswith("Bearer "):
|
||||||
|
return None
|
||||||
|
token = auth[7:]
|
||||||
|
if not token:
|
||||||
|
return None
|
||||||
|
if self._prefix is not None and not token.startswith(self._prefix):
|
||||||
|
return None
|
||||||
|
return token
|
||||||
|
|
||||||
|
async def authenticate(self, credential: str) -> Any:
|
||||||
|
"""Validate a credential and return the identity.
|
||||||
|
|
||||||
|
Calls ``await validator(credential, **kwargs)`` where ``kwargs`` are
|
||||||
|
the extra keyword arguments provided at instantiation.
|
||||||
|
"""
|
||||||
|
return await self._validate(credential)
|
||||||
|
|
||||||
|
def require(self, **kwargs: Any) -> "BearerTokenAuth":
|
||||||
|
"""Return a new instance with additional (or overriding) validator kwargs."""
|
||||||
|
return BearerTokenAuth(
|
||||||
|
self._validator,
|
||||||
|
prefix=self._prefix,
|
||||||
|
**{**self._kwargs, **kwargs},
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_token(self, nbytes: int = 32) -> str:
|
||||||
|
"""Generate a secure random token for this auth source.
|
||||||
|
|
||||||
|
Returns a URL-safe random token. If a prefix is configured it is
|
||||||
|
prepended — the returned value is what you store in your database
|
||||||
|
and return to the client as-is.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
nbytes: Number of random bytes before base64 encoding. The
|
||||||
|
resulting string is ``ceil(nbytes * 4 / 3)`` characters
|
||||||
|
(43 chars for the default 32 bytes). Defaults to 32.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A ready-to-use token string (e.g. ``"user_Xk3..."``).
|
||||||
|
"""
|
||||||
|
token = secrets.token_urlsafe(nbytes)
|
||||||
|
if self._prefix is not None:
|
||||||
|
return f"{self._prefix}{token}"
|
||||||
|
return token
|
||||||
139
src/fastapi_toolsets/security/sources/cookie.py
Normal file
139
src/fastapi_toolsets/security/sources/cookie.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
"""Cookie-based authentication source."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import Annotated, Any, Callable
|
||||||
|
|
||||||
|
from fastapi import Depends, Request, Response
|
||||||
|
from fastapi.security import APIKeyCookie, SecurityScopes
|
||||||
|
|
||||||
|
from fastapi_toolsets.exceptions import UnauthorizedError
|
||||||
|
|
||||||
|
from ..abc import AuthSource, _ensure_async
|
||||||
|
|
||||||
|
|
||||||
|
class CookieAuth(AuthSource):
|
||||||
|
"""Cookie-based authentication source.
|
||||||
|
|
||||||
|
Wraps :class:`fastapi.security.APIKeyCookie` for OpenAPI documentation.
|
||||||
|
Optionally signs the cookie with HMAC-SHA256 to provide stateless, tamper-
|
||||||
|
proof sessions without any database entry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Cookie name.
|
||||||
|
validator: Sync or async callable that receives the cookie value
|
||||||
|
(plain, after signature verification when ``secret_key`` is set)
|
||||||
|
and any extra keyword arguments, and returns the authenticated
|
||||||
|
identity.
|
||||||
|
secret_key: When provided, the cookie is HMAC-SHA256 signed.
|
||||||
|
:meth:`set_cookie` embeds an expiry and signs the payload;
|
||||||
|
:meth:`extract` verifies the signature and expiry before handing
|
||||||
|
the plain value to the validator. When ``None`` (default), the raw
|
||||||
|
cookie value is passed to the validator as-is.
|
||||||
|
ttl: Cookie lifetime in seconds (default 24 h). Only used when
|
||||||
|
``secret_key`` is set.
|
||||||
|
**kwargs: Extra keyword arguments forwarded to the validator on every
|
||||||
|
call (e.g. ``role=Role.ADMIN``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
validator: Callable[..., Any],
|
||||||
|
*,
|
||||||
|
secret_key: str | None = None,
|
||||||
|
ttl: int = 86400,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
self._name = name
|
||||||
|
self._validator = _ensure_async(validator)
|
||||||
|
self._secret_key = secret_key
|
||||||
|
self._ttl = ttl
|
||||||
|
self._kwargs = kwargs
|
||||||
|
self._scheme = APIKeyCookie(name=name, auto_error=False)
|
||||||
|
|
||||||
|
async def _call(
|
||||||
|
security_scopes: SecurityScopes, # noqa: ARG001
|
||||||
|
value: Annotated[str | None, Depends(self._scheme)] = None,
|
||||||
|
) -> Any:
|
||||||
|
if value is None:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
plain = self._verify(value)
|
||||||
|
return await self._validator(plain, **self._kwargs)
|
||||||
|
|
||||||
|
self._call_fn = _call
|
||||||
|
self.__signature__ = inspect.signature(_call)
|
||||||
|
|
||||||
|
def _hmac(self, data: str) -> str:
|
||||||
|
if self._secret_key is None:
|
||||||
|
raise RuntimeError("_hmac called without secret_key configured")
|
||||||
|
return hmac.new(
|
||||||
|
self._secret_key.encode(), data.encode(), hashlib.sha256
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
|
def _sign(self, value: str) -> str:
|
||||||
|
data = base64.urlsafe_b64encode(
|
||||||
|
json.dumps({"v": value, "exp": int(time.time()) + self._ttl}).encode()
|
||||||
|
).decode()
|
||||||
|
return f"{data}.{self._hmac(data)}"
|
||||||
|
|
||||||
|
def _verify(self, cookie_value: str) -> str:
|
||||||
|
"""Return the plain value, verifying HMAC + expiry when signed."""
|
||||||
|
if not self._secret_key:
|
||||||
|
return cookie_value
|
||||||
|
|
||||||
|
try:
|
||||||
|
data, sig = cookie_value.rsplit(".", 1)
|
||||||
|
except ValueError:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
|
||||||
|
if not hmac.compare_digest(self._hmac(data), sig):
|
||||||
|
raise UnauthorizedError()
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = json.loads(base64.urlsafe_b64decode(data))
|
||||||
|
value: str = payload["v"]
|
||||||
|
exp: int = payload["exp"]
|
||||||
|
except Exception:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
|
||||||
|
if exp < int(time.time()):
|
||||||
|
raise UnauthorizedError()
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
async def extract(self, request: Request) -> str | None:
|
||||||
|
return request.cookies.get(self._name)
|
||||||
|
|
||||||
|
async def authenticate(self, credential: str) -> Any:
|
||||||
|
plain = self._verify(credential)
|
||||||
|
return await self._validator(plain, **self._kwargs)
|
||||||
|
|
||||||
|
def require(self, **kwargs: Any) -> "CookieAuth":
|
||||||
|
"""Return a new instance with additional (or overriding) validator kwargs."""
|
||||||
|
return CookieAuth(
|
||||||
|
self._name,
|
||||||
|
self._validator,
|
||||||
|
secret_key=self._secret_key,
|
||||||
|
ttl=self._ttl,
|
||||||
|
**{**self._kwargs, **kwargs},
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_cookie(self, response: Response, value: str) -> None:
|
||||||
|
"""Attach the cookie to *response*, signing it when ``secret_key`` is set."""
|
||||||
|
cookie_value = self._sign(value) if self._secret_key else value
|
||||||
|
response.set_cookie(
|
||||||
|
self._name,
|
||||||
|
cookie_value,
|
||||||
|
httponly=True,
|
||||||
|
samesite="lax",
|
||||||
|
max_age=self._ttl,
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete_cookie(self, response: Response) -> None:
|
||||||
|
"""Clear the session cookie (logout)."""
|
||||||
|
response.delete_cookie(self._name, httponly=True, samesite="lax")
|
||||||
67
src/fastapi_toolsets/security/sources/header.py
Normal file
67
src/fastapi_toolsets/security/sources/header.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""API key header authentication source."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from typing import Annotated, Any, Callable
|
||||||
|
|
||||||
|
from fastapi import Depends, Request
|
||||||
|
from fastapi.security import APIKeyHeader, SecurityScopes
|
||||||
|
|
||||||
|
from fastapi_toolsets.exceptions import UnauthorizedError
|
||||||
|
|
||||||
|
from ..abc import AuthSource, _ensure_async
|
||||||
|
|
||||||
|
|
||||||
|
class APIKeyHeaderAuth(AuthSource):
|
||||||
|
"""API key header authentication source.
|
||||||
|
|
||||||
|
Wraps :class:`fastapi.security.APIKeyHeader` for OpenAPI documentation.
|
||||||
|
The validator is called as ``await validator(api_key, **kwargs)``
|
||||||
|
where ``kwargs`` are the extra keyword arguments provided at instantiation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: HTTP header name that carries the API key (e.g. ``"X-API-Key"``).
|
||||||
|
validator: Sync or async callable that receives the API key and any
|
||||||
|
extra keyword arguments, and returns the authenticated identity.
|
||||||
|
Should raise :class:`~fastapi_toolsets.exceptions.UnauthorizedError`
|
||||||
|
on failure.
|
||||||
|
**kwargs: Extra keyword arguments forwarded to the validator on every
|
||||||
|
call (e.g. ``role=Role.ADMIN``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
validator: Callable[..., Any],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
self._name = name
|
||||||
|
self._validator = _ensure_async(validator)
|
||||||
|
self._kwargs = kwargs
|
||||||
|
self._scheme = APIKeyHeader(name=name, auto_error=False)
|
||||||
|
|
||||||
|
async def _call(
|
||||||
|
security_scopes: SecurityScopes, # noqa: ARG001
|
||||||
|
api_key: Annotated[str | None, Depends(self._scheme)] = None,
|
||||||
|
) -> Any:
|
||||||
|
if api_key is None:
|
||||||
|
raise UnauthorizedError()
|
||||||
|
return await self._validator(api_key, **self._kwargs)
|
||||||
|
|
||||||
|
self._call_fn = _call
|
||||||
|
self.__signature__ = inspect.signature(_call)
|
||||||
|
|
||||||
|
async def extract(self, request: Request) -> str | None:
|
||||||
|
"""Extract the API key from the configured header."""
|
||||||
|
return request.headers.get(self._name) or None
|
||||||
|
|
||||||
|
async def authenticate(self, credential: str) -> Any:
|
||||||
|
"""Validate a credential and return the identity."""
|
||||||
|
return await self._validator(credential, **self._kwargs)
|
||||||
|
|
||||||
|
def require(self, **kwargs: Any) -> "APIKeyHeaderAuth":
|
||||||
|
"""Return a new instance with additional (or overriding) validator kwargs."""
|
||||||
|
return APIKeyHeaderAuth(
|
||||||
|
self._name,
|
||||||
|
self._validator,
|
||||||
|
**{**self._kwargs, **kwargs},
|
||||||
|
)
|
||||||
119
src/fastapi_toolsets/security/sources/multi.py
Normal file
119
src/fastapi_toolsets/security/sources/multi.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
"""MultiAuth: combine multiple authentication sources into a single callable."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from fastapi.security import SecurityScopes
|
||||||
|
|
||||||
|
from fastapi_toolsets.exceptions import UnauthorizedError
|
||||||
|
|
||||||
|
from ..abc import AuthSource
|
||||||
|
|
||||||
|
|
||||||
|
class MultiAuth:
|
||||||
|
"""Combine multiple authentication sources into a single callable.
|
||||||
|
|
||||||
|
Sources are tried in order; the first one whose
|
||||||
|
:meth:`~AuthSource.extract` returns a non-``None`` credential wins.
|
||||||
|
Its :meth:`~AuthSource.authenticate` is called and the result returned.
|
||||||
|
|
||||||
|
If a credential is found but the validator raises, the exception propagates
|
||||||
|
immediately — the remaining sources are **not** tried. This prevents
|
||||||
|
silent fallthrough on invalid credentials.
|
||||||
|
|
||||||
|
If no source provides a credential,
|
||||||
|
:class:`~fastapi_toolsets.exceptions.UnauthorizedError` is raised.
|
||||||
|
|
||||||
|
The :meth:`~AuthSource.extract` method of each source performs only
|
||||||
|
string matching (no I/O), so prefix-based dispatch is essentially free.
|
||||||
|
|
||||||
|
Any :class:`~AuthSource` subclass — including user-defined ones — can be
|
||||||
|
passed as a source.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*sources: Auth source instances to try in order.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
user_bearer = BearerTokenAuth(verify_user, prefix="user_")
|
||||||
|
org_bearer = BearerTokenAuth(verify_org, prefix="org_")
|
||||||
|
cookie = CookieAuth("session", verify_session)
|
||||||
|
|
||||||
|
multi = MultiAuth(user_bearer, org_bearer, cookie)
|
||||||
|
|
||||||
|
@app.get("/data")
|
||||||
|
async def data_route(user = Security(multi)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
# Apply a shared requirement to all sources at once
|
||||||
|
@app.get("/admin")
|
||||||
|
async def admin_route(user = Security(multi.require(role=Role.ADMIN))):
|
||||||
|
return user
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *sources: AuthSource) -> None:
|
||||||
|
self._sources = sources
|
||||||
|
|
||||||
|
async def _call(
|
||||||
|
request: Request,
|
||||||
|
security_scopes: SecurityScopes, # noqa: ARG001
|
||||||
|
**kwargs: Any, # noqa: ARG001 — absorbs scheme values injected by FastAPI
|
||||||
|
) -> Any:
|
||||||
|
for source in self._sources:
|
||||||
|
credential = await source.extract(request)
|
||||||
|
if credential is not None:
|
||||||
|
return await source.authenticate(credential)
|
||||||
|
raise UnauthorizedError()
|
||||||
|
|
||||||
|
self._call_fn = _call
|
||||||
|
|
||||||
|
# Build a merged signature that includes the security-scheme Depends()
|
||||||
|
# parameters from every source so FastAPI registers them in OpenAPI docs.
|
||||||
|
seen: set[str] = {"request", "security_scopes"}
|
||||||
|
merged: list[inspect.Parameter] = [
|
||||||
|
inspect.Parameter(
|
||||||
|
"request",
|
||||||
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||||
|
annotation=Request,
|
||||||
|
),
|
||||||
|
inspect.Parameter(
|
||||||
|
"security_scopes",
|
||||||
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||||
|
annotation=SecurityScopes,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for i, source in enumerate(sources):
|
||||||
|
for name, param in inspect.signature(source).parameters.items():
|
||||||
|
if name in seen:
|
||||||
|
continue
|
||||||
|
merged.append(param.replace(name=f"_s{i}_{name}"))
|
||||||
|
seen.add(name)
|
||||||
|
self.__signature__ = inspect.Signature(merged, return_annotation=Any)
|
||||||
|
|
||||||
|
async def __call__(self, **kwargs: Any) -> Any:
|
||||||
|
return await self._call_fn(**kwargs)
|
||||||
|
|
||||||
|
def require(self, **kwargs: Any) -> "MultiAuth":
|
||||||
|
"""Return a new :class:`MultiAuth` with kwargs forwarded to each source.
|
||||||
|
|
||||||
|
Calls ``.require(**kwargs)`` on every source that supports it. Sources
|
||||||
|
that do not implement ``.require()`` (e.g. custom :class:`~AuthSource`
|
||||||
|
subclasses) are passed through unchanged.
|
||||||
|
|
||||||
|
New kwargs are merged over each source's existing kwargs — new values
|
||||||
|
win on conflict::
|
||||||
|
|
||||||
|
multi = MultiAuth(bearer, cookie)
|
||||||
|
|
||||||
|
@app.get("/admin")
|
||||||
|
async def admin(user = Security(multi.require(role=Role.ADMIN))):
|
||||||
|
return user
|
||||||
|
"""
|
||||||
|
new_sources = tuple(
|
||||||
|
cast(Any, source).require(**kwargs)
|
||||||
|
if hasattr(source, "require")
|
||||||
|
else source
|
||||||
|
for source in self._sources
|
||||||
|
)
|
||||||
|
return MultiAuth(*new_sources)
|
||||||
27
src/fastapi_toolsets/types.py
Normal file
27
src/fastapi_toolsets/types.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
"""Shared type aliases for the fastapi-toolsets package."""
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator, Callable, Mapping
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute
|
||||||
|
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||||
|
from sqlalchemy.sql.elements import ColumnElement
|
||||||
|
|
||||||
|
# Generic TypeVars
|
||||||
|
DataT = TypeVar("DataT")
|
||||||
|
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||||
|
SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
||||||
|
|
||||||
|
# CRUD type aliases
|
||||||
|
JoinType = list[tuple[type[DeclarativeBase] | Any, Any]]
|
||||||
|
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
|
||||||
|
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]
|
||||||
|
|
||||||
|
# Search / facet type aliases
|
||||||
|
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
|
||||||
|
FacetFieldType = SearchFieldType
|
||||||
|
|
||||||
|
# Dependency type aliases
|
||||||
|
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]] | Any
|
||||||
@@ -1,27 +1,38 @@
|
|||||||
"""Shared pytest fixtures for fastapi-utils tests."""
|
"""Shared pytest fixtures for fastapi-utils tests."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import ForeignKey, String
|
import datetime
|
||||||
|
import decimal
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
Column,
|
||||||
|
Date,
|
||||||
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
|
Integer,
|
||||||
|
JSON,
|
||||||
|
Numeric,
|
||||||
|
String,
|
||||||
|
Table,
|
||||||
|
Uuid,
|
||||||
|
)
|
||||||
|
from sqlalchemy.dialects.postgresql import ARRAY
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from fastapi_toolsets.crud import CrudFactory
|
from fastapi_toolsets.crud import CrudFactory
|
||||||
|
from fastapi_toolsets.schemas import PydanticBase
|
||||||
|
|
||||||
# PostgreSQL connection URL from environment or default for local development
|
DATABASE_URL = os.getenv(
|
||||||
DATABASE_URL = os.getenv("DATABASE_URL") or os.getenv(
|
key="DATABASE_URL",
|
||||||
"TEST_DATABASE_URL",
|
default="postgresql+asyncpg://postgres:postgres@localhost:5432/postgres",
|
||||||
"postgresql+asyncpg://postgres:postgres@localhost:5432/fastapi_toolsets_test",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Test Models
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
class Base(DeclarativeBase):
|
||||||
"""Base class for test models."""
|
"""Base class for test models."""
|
||||||
|
|
||||||
@@ -33,7 +44,7 @@ class Role(Base):
|
|||||||
|
|
||||||
__tablename__ = "roles"
|
__tablename__ = "roles"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
name: Mapped[str] = mapped_column(String(50), unique=True)
|
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
|
|
||||||
users: Mapped[list["User"]] = relationship(back_populates="role")
|
users: Mapped[list["User"]] = relationship(back_populates="role")
|
||||||
@@ -44,36 +55,123 @@ class User(Base):
|
|||||||
|
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
username: Mapped[str] = mapped_column(String(50), unique=True)
|
username: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
email: Mapped[str] = mapped_column(String(100), unique=True)
|
email: Mapped[str] = mapped_column(String(100), unique=True)
|
||||||
is_active: Mapped[bool] = mapped_column(default=True)
|
is_active: Mapped[bool] = mapped_column(default=True)
|
||||||
role_id: Mapped[int | None] = mapped_column(ForeignKey("roles.id"), nullable=True)
|
notes: Mapped[str | None]
|
||||||
|
role_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
|
ForeignKey("roles.id"), nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
role: Mapped[Role | None] = relationship(back_populates="users")
|
role: Mapped[Role | None] = relationship(back_populates="users")
|
||||||
|
|
||||||
|
|
||||||
|
class Tag(Base):
|
||||||
|
"""Test tag model."""
|
||||||
|
|
||||||
|
__tablename__ = "tags"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
|
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
|
|
||||||
|
|
||||||
|
post_tags = Table(
|
||||||
|
"post_tags",
|
||||||
|
Base.metadata,
|
||||||
|
Column(
|
||||||
|
"post_id", Uuid, ForeignKey("posts.id", ondelete="CASCADE"), primary_key=True
|
||||||
|
),
|
||||||
|
Column("tag_id", Uuid, ForeignKey("tags.id", ondelete="CASCADE"), primary_key=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class IntRole(Base):
|
||||||
|
"""Test role model with auto-increment integer PK."""
|
||||||
|
|
||||||
|
__tablename__ = "int_roles"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
|
|
||||||
|
|
||||||
|
class Permission(Base):
|
||||||
|
"""Test model with composite primary key."""
|
||||||
|
|
||||||
|
__tablename__ = "permissions"
|
||||||
|
|
||||||
|
subject: Mapped[str] = mapped_column(String(50), primary_key=True)
|
||||||
|
action: Mapped[str] = mapped_column(String(50), primary_key=True)
|
||||||
|
|
||||||
|
|
||||||
|
class Event(Base):
|
||||||
|
"""Test model with DateTime and Date cursor columns."""
|
||||||
|
|
||||||
|
__tablename__ = "events"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
|
name: Mapped[str] = mapped_column(String(100))
|
||||||
|
occurred_at: Mapped[datetime.datetime] = mapped_column(DateTime)
|
||||||
|
scheduled_date: Mapped[datetime.date] = mapped_column(Date)
|
||||||
|
|
||||||
|
|
||||||
|
class Product(Base):
|
||||||
|
"""Test model with Numeric cursor column."""
|
||||||
|
|
||||||
|
__tablename__ = "products"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
|
name: Mapped[str] = mapped_column(String(100))
|
||||||
|
price: Mapped[decimal.Decimal] = mapped_column(Numeric(10, 2))
|
||||||
|
|
||||||
|
|
||||||
class Post(Base):
|
class Post(Base):
|
||||||
"""Test post model."""
|
"""Test post model."""
|
||||||
|
|
||||||
__tablename__ = "posts"
|
__tablename__ = "posts"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
title: Mapped[str] = mapped_column(String(200))
|
title: Mapped[str] = mapped_column(String(200))
|
||||||
content: Mapped[str] = mapped_column(String(1000), default="")
|
content: Mapped[str] = mapped_column(String(1000), default="")
|
||||||
is_published: Mapped[bool] = mapped_column(default=False)
|
is_published: Mapped[bool] = mapped_column(default=False)
|
||||||
author_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
|
author_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
|
||||||
|
|
||||||
|
tags: Mapped[list[Tag]] = relationship(secondary=post_tags)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
class Transfer(Base):
|
||||||
# Test Schemas
|
"""Test model with two FKs to the same table (users)."""
|
||||||
# =============================================================================
|
|
||||||
|
__tablename__ = "transfers"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
|
amount: Mapped[str] = mapped_column(String(50))
|
||||||
|
sender_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
|
||||||
|
receiver_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
|
||||||
|
|
||||||
|
|
||||||
|
class Article(Base):
|
||||||
|
"""Test article model with ARRAY and JSON columns."""
|
||||||
|
|
||||||
|
__tablename__ = "articles"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
|
title: Mapped[str] = mapped_column(String(200))
|
||||||
|
labels: Mapped[list[str]] = mapped_column(ARRAY(String))
|
||||||
|
metadata_: Mapped[dict | None] = mapped_column("metadata", JSON, nullable=True)
|
||||||
|
|
||||||
|
|
||||||
class RoleCreate(BaseModel):
|
class RoleCreate(BaseModel):
|
||||||
"""Schema for creating a role."""
|
"""Schema for creating a role."""
|
||||||
|
|
||||||
id: int | None = None
|
id: uuid.UUID | None = None
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class RoleRead(PydanticBase):
|
||||||
|
"""Schema for reading a role."""
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
@@ -86,11 +184,19 @@ class RoleUpdate(BaseModel):
|
|||||||
class UserCreate(BaseModel):
|
class UserCreate(BaseModel):
|
||||||
"""Schema for creating a user."""
|
"""Schema for creating a user."""
|
||||||
|
|
||||||
id: int | None = None
|
id: uuid.UUID | None = None
|
||||||
username: str
|
username: str
|
||||||
email: str
|
email: str
|
||||||
is_active: bool = True
|
is_active: bool = True
|
||||||
role_id: int | None = None
|
role_id: uuid.UUID | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class UserRead(PydanticBase):
|
||||||
|
"""Schema for reading a user (subset of fields — no email)."""
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
|
username: str
|
||||||
|
is_active: bool = True
|
||||||
|
|
||||||
|
|
||||||
class UserUpdate(BaseModel):
|
class UserUpdate(BaseModel):
|
||||||
@@ -99,17 +205,24 @@ class UserUpdate(BaseModel):
|
|||||||
username: str | None = None
|
username: str | None = None
|
||||||
email: str | None = None
|
email: str | None = None
|
||||||
is_active: bool | None = None
|
is_active: bool | None = None
|
||||||
role_id: int | None = None
|
role_id: uuid.UUID | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TagCreate(BaseModel):
|
||||||
|
"""Schema for creating a tag."""
|
||||||
|
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
class PostCreate(BaseModel):
|
class PostCreate(BaseModel):
|
||||||
"""Schema for creating a post."""
|
"""Schema for creating a post."""
|
||||||
|
|
||||||
id: int | None = None
|
id: uuid.UUID | None = None
|
||||||
title: str
|
title: str
|
||||||
content: str = ""
|
content: str = ""
|
||||||
is_published: bool = False
|
is_published: bool = False
|
||||||
author_id: int
|
author_id: uuid.UUID
|
||||||
|
|
||||||
|
|
||||||
class PostUpdate(BaseModel):
|
class PostUpdate(BaseModel):
|
||||||
@@ -120,18 +233,115 @@ class PostUpdate(BaseModel):
|
|||||||
is_published: bool | None = None
|
is_published: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
class PostM2MCreate(BaseModel):
|
||||||
# CRUD Classes
|
"""Schema for creating a post with M2M tag IDs."""
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
title: str
|
||||||
|
content: str = ""
|
||||||
|
is_published: bool = False
|
||||||
|
author_id: uuid.UUID
|
||||||
|
tag_ids: list[uuid.UUID] = []
|
||||||
|
|
||||||
|
|
||||||
|
class PostM2MUpdate(BaseModel):
|
||||||
|
"""Schema for updating a post with M2M tag IDs."""
|
||||||
|
|
||||||
|
title: str | None = None
|
||||||
|
content: str | None = None
|
||||||
|
is_published: bool | None = None
|
||||||
|
tag_ids: list[uuid.UUID] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class IntRoleRead(PydanticBase):
|
||||||
|
"""Schema for reading an IntRole."""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class IntRoleCreate(BaseModel):
|
||||||
|
"""Schema for creating an IntRole."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class EventRead(PydanticBase):
|
||||||
|
"""Schema for reading an Event."""
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class EventCreate(BaseModel):
|
||||||
|
"""Schema for creating an Event."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
occurred_at: datetime.datetime
|
||||||
|
scheduled_date: datetime.date
|
||||||
|
|
||||||
|
|
||||||
|
class ProductRead(PydanticBase):
|
||||||
|
"""Schema for reading a Product."""
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class ProductCreate(BaseModel):
|
||||||
|
"""Schema for creating a Product."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
price: decimal.Decimal
|
||||||
|
|
||||||
|
|
||||||
|
class ArticleCreate(BaseModel):
|
||||||
|
"""Schema for creating an article."""
|
||||||
|
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
title: str
|
||||||
|
labels: list[str] = []
|
||||||
|
|
||||||
|
|
||||||
|
class ArticleRead(PydanticBase):
|
||||||
|
"""Schema for reading an article."""
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
|
title: str
|
||||||
|
labels: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class TransferCreate(BaseModel):
|
||||||
|
"""Schema for creating a transfer."""
|
||||||
|
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
amount: str
|
||||||
|
sender_id: uuid.UUID
|
||||||
|
receiver_id: uuid.UUID
|
||||||
|
|
||||||
|
|
||||||
|
class TransferRead(PydanticBase):
|
||||||
|
"""Schema for reading a transfer."""
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
|
amount: str
|
||||||
|
|
||||||
|
|
||||||
|
TransferCrud = CrudFactory(Transfer)
|
||||||
|
ArticleCrud = CrudFactory(Article)
|
||||||
RoleCrud = CrudFactory(Role)
|
RoleCrud = CrudFactory(Role)
|
||||||
|
RoleCursorCrud = CrudFactory(Role, cursor_column=Role.id)
|
||||||
|
IntRoleCursorCrud = CrudFactory(IntRole, cursor_column=IntRole.id)
|
||||||
UserCrud = CrudFactory(User)
|
UserCrud = CrudFactory(User)
|
||||||
|
UserCursorCrud = CrudFactory(User, cursor_column=User.id)
|
||||||
PostCrud = CrudFactory(Post)
|
PostCrud = CrudFactory(Post)
|
||||||
|
TagCrud = CrudFactory(Tag)
|
||||||
|
PostM2MCrud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
|
||||||
# =============================================================================
|
EventCrud = CrudFactory(Event)
|
||||||
# Fixtures
|
EventDateTimeCursorCrud = CrudFactory(Event, cursor_column=Event.occurred_at)
|
||||||
# =============================================================================
|
EventDateCursorCrud = CrudFactory(Event, cursor_column=Event.scheduled_date)
|
||||||
|
ProductCrud = CrudFactory(Product)
|
||||||
|
ProductNumericCursorCrud = CrudFactory(Product, cursor_column=Product.price)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -170,30 +380,3 @@ async def db_session(engine):
|
|||||||
# Drop tables after test
|
# Drop tables after test
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(Base.metadata.drop_all)
|
await conn.run_sync(Base.metadata.drop_all)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_role_data() -> RoleCreate:
|
|
||||||
"""Sample role creation data."""
|
|
||||||
return RoleCreate(name="admin")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_user_data() -> UserCreate:
|
|
||||||
"""Sample user creation data."""
|
|
||||||
return UserCreate(
|
|
||||||
username="testuser",
|
|
||||||
email="test@example.com",
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_post_data() -> PostCreate:
|
|
||||||
"""Sample post creation data."""
|
|
||||||
return PostCreate(
|
|
||||||
title="Test Post",
|
|
||||||
content="Test content",
|
|
||||||
is_published=True,
|
|
||||||
author_id=1,
|
|
||||||
)
|
|
||||||
|
|||||||
543
tests/test_cli.py
Normal file
543
tests/test_cli.py
Normal file
@@ -0,0 +1,543 @@
|
|||||||
|
"""Tests for fastapi_toolsets.cli module."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli.config import (
|
||||||
|
get_config_value,
|
||||||
|
get_custom_cli,
|
||||||
|
get_db_context,
|
||||||
|
get_fixtures_registry,
|
||||||
|
import_from_string,
|
||||||
|
)
|
||||||
|
from fastapi_toolsets.cli.pyproject import find_pyproject, load_pyproject
|
||||||
|
from fastapi_toolsets.cli.utils import async_command
|
||||||
|
from fastapi_toolsets.fixtures import FixtureRegistry
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
|
||||||
|
|
||||||
|
class TestPyproject:
|
||||||
|
"""Tests for pyproject.toml discovery and loading."""
|
||||||
|
|
||||||
|
def test_find_pyproject_in_current_dir(self, tmp_path, monkeypatch):
|
||||||
|
"""Finds pyproject.toml in current directory."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[project]\nname = 'test'\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = find_pyproject()
|
||||||
|
assert result == pyproject
|
||||||
|
|
||||||
|
def test_find_pyproject_in_parent_dir(self, tmp_path, monkeypatch):
|
||||||
|
"""Finds pyproject.toml in parent directory."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[project]\nname = 'test'\n")
|
||||||
|
subdir = tmp_path / "src" / "app"
|
||||||
|
subdir.mkdir(parents=True)
|
||||||
|
monkeypatch.chdir(subdir)
|
||||||
|
|
||||||
|
result = find_pyproject()
|
||||||
|
assert result == pyproject
|
||||||
|
|
||||||
|
def test_find_pyproject_not_found(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns None when no pyproject.toml exists."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
result = find_pyproject()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_load_pyproject_returns_tool_config(self, tmp_path, monkeypatch):
|
||||||
|
"""load_pyproject returns the [tool.fastapi-toolsets] section."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
'[tool.fastapi-toolsets]\nfixtures = "app.fixtures:registry"\n'
|
||||||
|
)
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = load_pyproject()
|
||||||
|
assert result == {"fixtures": "app.fixtures:registry"}
|
||||||
|
|
||||||
|
def test_load_pyproject_empty_when_no_file(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns empty dict when no pyproject.toml exists."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
result = load_pyproject()
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_load_pyproject_empty_when_no_tool_section(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns empty dict when no [tool.fastapi-toolsets] section."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[project]\nname = 'test'\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = load_pyproject()
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_load_pyproject_invalid_toml(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns empty dict when pyproject.toml is invalid."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("invalid toml {{{")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = load_pyproject()
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestImportFromString:
|
||||||
|
"""Tests for import_from_string function."""
|
||||||
|
|
||||||
|
def test_import_valid_path(self):
|
||||||
|
"""Import valid module:attribute path."""
|
||||||
|
result = import_from_string("fastapi_toolsets.fixtures:FixtureRegistry")
|
||||||
|
assert result is FixtureRegistry
|
||||||
|
|
||||||
|
def test_import_without_colon_raises_error(self):
|
||||||
|
"""Import path without colon raises error."""
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
import_from_string("fastapi_toolsets.fixtures.FixtureRegistry")
|
||||||
|
assert "Expected format: 'module:attribute'" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_import_nonexistent_module_raises_error(self):
|
||||||
|
"""Import nonexistent module raises error."""
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
import_from_string("nonexistent.module:something")
|
||||||
|
assert "Cannot import module" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_import_nonexistent_attribute_raises_error(self):
|
||||||
|
"""Import nonexistent attribute raises error."""
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
import_from_string("fastapi_toolsets.fixtures:NonexistentClass")
|
||||||
|
assert "has no attribute" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetConfigValue:
|
||||||
|
"""Tests for get_config_value function."""
|
||||||
|
|
||||||
|
def test_get_existing_value(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns value when key exists."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text('[tool.fastapi-toolsets]\nfixtures = "app:registry"\n')
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = get_config_value("fixtures")
|
||||||
|
assert result == "app:registry"
|
||||||
|
|
||||||
|
def test_get_missing_value_returns_none(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns None when key is missing and not required."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = get_config_value("fixtures")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_get_missing_value_required_raises_error(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when key is missing and required."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_config_value("fixtures", required=True)
|
||||||
|
assert "No 'fixtures' configured" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetFixturesRegistry:
|
||||||
|
"""Tests for get_fixtures_registry function."""
|
||||||
|
|
||||||
|
def test_raises_when_not_configured(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when fixtures not configured."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_fixtures_registry()
|
||||||
|
assert "No 'fixtures' configured" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_raises_when_not_registry_instance(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when imported object is not a FixtureRegistry."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
'[tool.fastapi-toolsets]\nfixtures = "my_fixtures:registry"\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
fixtures_file = tmp_path / "my_fixtures.py"
|
||||||
|
fixtures_file.write_text("registry = 'not a registry'\n")
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
try:
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_fixtures_registry()
|
||||||
|
assert "must be a FixtureRegistry instance" in str(exc_info.value)
|
||||||
|
finally:
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
if "my_fixtures" in sys.modules:
|
||||||
|
del sys.modules["my_fixtures"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetDbContext:
|
||||||
|
"""Tests for get_db_context function."""
|
||||||
|
|
||||||
|
def test_raises_when_not_configured(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when db_context not configured."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_db_context()
|
||||||
|
assert "No 'db_context' configured" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCustomCli:
|
||||||
|
"""Tests for get_custom_cli function."""
|
||||||
|
|
||||||
|
def test_returns_none_when_not_configured(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns None when custom_cli not configured."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = get_custom_cli()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_raises_when_not_typer_instance(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when imported object is not a Typer instance."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text('[tool.fastapi-toolsets]\ncustom_cli = "my_cli:cli"\n')
|
||||||
|
|
||||||
|
cli_file = tmp_path / "my_cli.py"
|
||||||
|
cli_file.write_text("cli = 'not a typer'\n")
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
try:
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_custom_cli()
|
||||||
|
assert "must be a Typer instance" in str(exc_info.value)
|
||||||
|
finally:
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
if "my_cli" in sys.modules:
|
||||||
|
del sys.modules["my_cli"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestCliApp:
|
||||||
|
"""Tests for CLI application."""
|
||||||
|
|
||||||
|
def test_cli_help(self, tmp_path, monkeypatch):
|
||||||
|
"""CLI shows help without fixtures."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Need to reload the module to pick up new cwd
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
result = runner.invoke(app.cli, ["--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "CLI utilities for FastAPI projects" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
class TestFixturesCli:
|
||||||
|
"""Tests for fixtures CLI commands."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def cli_env(self, tmp_path, monkeypatch):
|
||||||
|
"""Set up CLI environment with fixtures config."""
|
||||||
|
# Create pyproject.toml
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
"[tool.fastapi-toolsets]\n"
|
||||||
|
'fixtures = "fixtures:registry"\n'
|
||||||
|
'db_context = "db:get_session"\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create fixtures module
|
||||||
|
fixtures_file = tmp_path / "fixtures.py"
|
||||||
|
fixtures_file.write_text(
|
||||||
|
"from fastapi_toolsets.fixtures import FixtureRegistry, Context\n"
|
||||||
|
"\n"
|
||||||
|
"registry = FixtureRegistry()\n"
|
||||||
|
"\n"
|
||||||
|
"@registry.register(contexts=[Context.BASE])\n"
|
||||||
|
"def roles():\n"
|
||||||
|
' return [{"id": 1, "name": "admin"}, {"id": 2, "name": "user"}]\n'
|
||||||
|
"\n"
|
||||||
|
'@registry.register(depends_on=["roles"], contexts=[Context.TESTING])\n'
|
||||||
|
"def users():\n"
|
||||||
|
' return [{"id": 1, "name": "alice", "role_id": 1}]\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create db module
|
||||||
|
db_file = tmp_path / "db.py"
|
||||||
|
db_file.write_text(
|
||||||
|
"from contextlib import asynccontextmanager\n"
|
||||||
|
"\n"
|
||||||
|
"@asynccontextmanager\n"
|
||||||
|
"async def get_session():\n"
|
||||||
|
" yield None\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Add tmp_path to sys.path for imports
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
# Reload the CLI module to pick up new config
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
yield tmp_path, app.cli
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
|
||||||
|
def test_fixtures_list(self, cli_env):
|
||||||
|
"""fixtures list shows registered fixtures."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(cli, ["fixtures", "list"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "roles" in result.output
|
||||||
|
assert "users" in result.output
|
||||||
|
assert "Total: 2 fixture(s)" in result.output
|
||||||
|
|
||||||
|
def test_fixtures_list_with_context(self, cli_env):
|
||||||
|
"""fixtures list --context filters by context."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(cli, ["fixtures", "list", "--context", "base"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "roles" in result.output
|
||||||
|
assert "users" not in result.output
|
||||||
|
assert "Total: 1 fixture(s)" in result.output
|
||||||
|
|
||||||
|
def test_fixtures_load_dry_run(self, cli_env):
|
||||||
|
"""fixtures load --dry-run shows what would be loaded."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(cli, ["fixtures", "load", "base", "--dry-run"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Fixtures to load" in result.output
|
||||||
|
assert "roles" in result.output
|
||||||
|
assert "[Dry run - no changes made]" in result.output
|
||||||
|
|
||||||
|
def test_fixtures_load_invalid_strategy(self, cli_env):
|
||||||
|
"""fixtures load with invalid strategy shows error."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(
|
||||||
|
cli, ["fixtures", "load", "base", "--strategy", "invalid"]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code != 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestCliWithoutFixturesConfig:
|
||||||
|
"""Tests for CLI when fixtures is not configured."""
|
||||||
|
|
||||||
|
def test_no_fixtures_command(self, tmp_path, monkeypatch):
|
||||||
|
"""fixtures command is not available when not configured."""
|
||||||
|
# Create pyproject.toml without fixtures
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text('[project]\nname = "test"\n')
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Reload the CLI module
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
result = runner.invoke(app.cli, ["--help"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "fixtures" not in result.output
|
||||||
|
|
||||||
|
|
||||||
|
class TestCustomCliConfig:
|
||||||
|
"""Tests for custom CLI configuration."""
|
||||||
|
|
||||||
|
def test_cli_with_custom_cli(self, tmp_path, monkeypatch):
|
||||||
|
"""CLI uses custom Typer instance when configured."""
|
||||||
|
import typer
|
||||||
|
|
||||||
|
# Create pyproject.toml with custom_cli config
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text('[tool.fastapi-toolsets]\ncustom_cli = "my_cli:cli"\n')
|
||||||
|
|
||||||
|
# Create custom CLI module with its own Typer and commands
|
||||||
|
cli_file = tmp_path / "my_cli.py"
|
||||||
|
cli_file.write_text(
|
||||||
|
"import typer\n"
|
||||||
|
"\n"
|
||||||
|
"cli = typer.Typer(name='my-app', help='My custom CLI')\n"
|
||||||
|
"\n"
|
||||||
|
"@cli.command()\n"
|
||||||
|
"def hello():\n"
|
||||||
|
' print("Hello from custom CLI!")\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Add tmp_path to sys.path for imports
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
# Remove my_cli from sys.modules if it was previously loaded
|
||||||
|
if "my_cli" in sys.modules:
|
||||||
|
del sys.modules["my_cli"]
|
||||||
|
|
||||||
|
# Reload the CLI module to pick up new config
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Verify custom CLI is used
|
||||||
|
assert isinstance(app.cli, typer.Typer)
|
||||||
|
|
||||||
|
result = runner.invoke(app.cli, ["--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "My custom CLI" in result.output
|
||||||
|
assert "hello" in result.output
|
||||||
|
|
||||||
|
result = runner.invoke(app.cli, ["hello"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Hello from custom CLI!" in result.output
|
||||||
|
finally:
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
if "my_cli" in sys.modules:
|
||||||
|
del sys.modules["my_cli"]
|
||||||
|
|
||||||
|
def test_custom_cli_with_fixtures(self, tmp_path, monkeypatch):
|
||||||
|
"""Custom CLI gets fixtures command added when configured."""
|
||||||
|
# Create pyproject.toml with both custom_cli and fixtures
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
"[tool.fastapi-toolsets]\n"
|
||||||
|
'custom_cli = "my_cli:cli"\n'
|
||||||
|
'fixtures = "fixtures:registry"\n'
|
||||||
|
'db_context = "db:get_session"\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create custom CLI module
|
||||||
|
cli_file = tmp_path / "my_cli.py"
|
||||||
|
cli_file.write_text(
|
||||||
|
"import typer\n"
|
||||||
|
"\n"
|
||||||
|
"cli = typer.Typer(name='my-app', help='My custom CLI')\n"
|
||||||
|
"\n"
|
||||||
|
"@cli.command()\n"
|
||||||
|
"def hello():\n"
|
||||||
|
' print("Hello!")\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create fixtures module
|
||||||
|
fixtures_file = tmp_path / "fixtures.py"
|
||||||
|
fixtures_file.write_text(
|
||||||
|
"from fastapi_toolsets.fixtures import FixtureRegistry\n"
|
||||||
|
"\n"
|
||||||
|
"registry = FixtureRegistry()\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create db module
|
||||||
|
db_file = tmp_path / "db.py"
|
||||||
|
db_file.write_text(
|
||||||
|
"from contextlib import asynccontextmanager\n"
|
||||||
|
"\n"
|
||||||
|
"@asynccontextmanager\n"
|
||||||
|
"async def get_session():\n"
|
||||||
|
" yield None\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
for mod in ["my_cli", "fixtures", "db"]:
|
||||||
|
if mod in sys.modules:
|
||||||
|
del sys.modules[mod]
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = runner.invoke(app.cli, ["--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
# Should have both custom command and fixtures
|
||||||
|
assert "hello" in result.output
|
||||||
|
assert "fixtures" in result.output
|
||||||
|
finally:
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
for mod in ["my_cli", "fixtures", "db"]:
|
||||||
|
if mod in sys.modules:
|
||||||
|
del sys.modules[mod]
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncCommand:
|
||||||
|
"""Tests for async_command decorator."""
|
||||||
|
|
||||||
|
def test_async_command_runs_coroutine(self):
|
||||||
|
"""async_command runs async function synchronously."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def async_func(value: int) -> int:
|
||||||
|
return value * 2
|
||||||
|
|
||||||
|
result = async_func(21)
|
||||||
|
assert result == 42
|
||||||
|
|
||||||
|
def test_async_command_preserves_signature(self):
|
||||||
|
"""async_command preserves function signature."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def async_func(name: str, count: int = 1) -> str:
|
||||||
|
return f"{name} x {count}"
|
||||||
|
|
||||||
|
result = async_func("test", count=3)
|
||||||
|
assert result == "test x 3"
|
||||||
|
|
||||||
|
def test_async_command_preserves_docstring(self):
|
||||||
|
"""async_command preserves function docstring."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def async_func() -> None:
|
||||||
|
"""This is a docstring."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert async_func.__doc__ == """This is a docstring."""
|
||||||
|
|
||||||
|
def test_async_command_preserves_name(self):
|
||||||
|
"""async_command preserves function name."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def my_async_function() -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert my_async_function.__name__ == "my_async_function"
|
||||||
2356
tests/test_crud.py
2356
tests/test_crud.py
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user