mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 06:36:26 +02:00
Compare commits
162 Commits
v0.2.0
...
bbe63edc46
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bbe63edc46 | ||
|
|
0b17c77dee | ||
|
|
bce71bfd42 | ||
|
|
2f1eb4d468 | ||
|
|
1f06eab11d | ||
|
|
fac9aa6f60 | ||
|
|
f310466697 | ||
|
|
32059dcb02 | ||
|
|
f027981e80 | ||
|
|
5c1487c24a | ||
|
|
ebaa61525f | ||
|
|
4829cfba73 | ||
|
|
9ca2da4213 | ||
|
|
0b3f097012 | ||
|
|
1890d696bf | ||
|
|
104285c6e5 | ||
|
|
f5afbbe37f | ||
|
|
f4698bea8a | ||
|
|
5215b921ae | ||
|
9dad59e25d
|
|||
|
|
29326ab532 | ||
|
|
04afef7e33 | ||
|
|
666c621fda | ||
|
460b760fa4
|
|||
|
|
65d0b0e0b1 | ||
|
|
2d49cd32db | ||
|
|
a5dd756d87 | ||
|
|
781cfb66c9 | ||
|
|
91b84f8146 | ||
|
|
396e381ac3 | ||
|
|
b4eb4c1ca9 | ||
|
c90717754f
|
|||
|
337985ef38
|
|||
|
|
b5e6dfe6fe | ||
|
|
6681b7ade7 | ||
|
|
6981c33dc8 | ||
|
|
0c7a99039c | ||
|
|
bcb5b0bfda | ||
|
100e1c1aa9
|
|||
|
|
db6c7a565f | ||
|
|
768e405554 | ||
|
|
f0223ebde4 | ||
|
|
f8c9bf69fe | ||
|
6d6fae5538
|
|||
|
|
fc9cd1f034 | ||
|
|
f82225f995 | ||
|
|
e62612a93a | ||
|
|
56f0ea291e | ||
|
|
ee896009ee | ||
|
|
65bf928e12 | ||
|
|
2e9c6c0c90 | ||
|
2c494fcd17
|
|||
|
|
fd7269a372 | ||
|
|
c863744012 | ||
|
|
aedcbf4e04 | ||
|
|
19c013bdec | ||
|
|
81407c3038 | ||
|
|
0fb00d44da | ||
|
|
19232d3436 | ||
|
1eafcb3873
|
|||
|
|
0d67fbb58d | ||
|
|
a59f098930 | ||
|
|
96e34ba8af | ||
|
|
26d649791f | ||
|
dde5183e68
|
|||
|
|
e4250a9910 | ||
|
|
4800941934 | ||
|
0cc21d2012
|
|||
|
|
a3245d50f0 | ||
|
|
baebf022f6 | ||
|
|
96d445e3f3 | ||
|
|
80306e1af3 | ||
|
|
fd999b63f1 | ||
|
|
c0f352b914 | ||
|
c4c760484b
|
|||
|
|
432e0722e0 | ||
|
|
e732e54518 | ||
|
|
05b5a2c876 | ||
|
|
4a020c56d1 | ||
|
56d365d14b
|
|||
|
|
a257d85d45 | ||
|
117675d02f
|
|||
|
|
d7ad7308c5 | ||
|
8d57bf9525
|
|||
|
|
5a08ec2f57 | ||
|
|
433dc55fcd | ||
|
|
0b2abd8c43 | ||
|
|
07c99be89b | ||
|
|
9b75cc7dfc | ||
|
|
6144b383eb | ||
|
7ec407834a
|
|||
|
|
7da34f33a2 | ||
|
8c8911fb27
|
|||
|
|
c0c3b38054 | ||
|
e17d385910
|
|||
|
|
6cf7df55ef | ||
|
|
7482bc5dad | ||
|
|
9d07dfea85 | ||
|
|
31678935aa | ||
|
|
823a0b3e36 | ||
|
1591cd3d64
|
|||
|
|
6714ceeb92 | ||
|
|
73fae04333 | ||
|
|
32ed36e102 | ||
|
|
48567310bc | ||
|
|
de51ed4675 | ||
|
|
794767edbb | ||
|
|
9c136f05bb | ||
|
3299a439fe
|
|||
|
|
d5b22a72fd | ||
|
|
c32f2e18be | ||
|
d971261f98
|
|||
|
|
74a54b7396 | ||
|
|
19805ab376 | ||
|
|
d4498e2063 | ||
| f59c1a17e2 | |||
|
|
8982ba18e3 | ||
|
|
71fe6f478f | ||
|
|
1cfbf14986 | ||
|
|
e3ff535b7e | ||
|
|
8825c772ce | ||
|
|
c8c263ca8f | ||
|
2020fa2f92
|
|||
|
|
1ea316bef4 | ||
|
|
ced1a655f2 | ||
|
|
290b2a06ec | ||
|
|
baa9711665 | ||
|
d526969d0e
|
|||
|
|
e24153053e | ||
|
348ed4c148
|
|||
|
bd6e90de1b
|
|||
|
|
4404fb3df9 | ||
|
|
f68793fbdb | ||
|
|
3a69c3c788 | ||
|
e861a0a49a
|
|||
|
|
cb2cf572e0 | ||
|
494869a172
|
|||
|
|
e0bc93096d | ||
|
1ff94eb9d3
|
|||
|
|
97ab10edcd | ||
|
|
3ff7ff18bb | ||
|
0f50c8a0f0
|
|||
|
|
691fb78fda | ||
|
|
34ef4da317 | ||
|
|
8c287b3ce7 | ||
|
54f5479c24
|
|||
|
|
f467754df1 | ||
|
b57ce40b05
|
|||
|
5264631550
|
|||
|
a76f7c439d
|
|||
|
|
d14551781c | ||
|
|
577e087321 | ||
|
aa72dc2eb5
|
|||
|
|
1a98e36909 | ||
|
ba5180a73b
|
|||
|
a9f486d905
|
|||
|
53e80cd0d5
|
|||
|
|
45001767aa | ||
|
|
cd551b6bff | ||
|
|
718a12be28 | ||
|
|
fa16bf1bff | ||
|
|
c4a227f9fc |
14
.github/dependabot.yml
vendored
Normal file
14
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
version: 2
|
||||||
|
updates:
|
||||||
|
- package-ecosystem: "github-actions"
|
||||||
|
directory: "/"
|
||||||
|
schedule:
|
||||||
|
interval: "weekly"
|
||||||
|
commit-message:
|
||||||
|
prefix: ⬆
|
||||||
|
- package-ecosystem: "uv"
|
||||||
|
directory: "/"
|
||||||
|
schedule:
|
||||||
|
interval: "weekly"
|
||||||
|
commit-message:
|
||||||
|
prefix: ⬆
|
||||||
8
.github/workflows/build-release.yml
vendored
8
.github/workflows/build-release.yml
vendored
@@ -11,16 +11,16 @@ jobs:
|
|||||||
permissions:
|
permissions:
|
||||||
id-token: write
|
id-token: write
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Install uv
|
- name: Install uv
|
||||||
uses: astral-sh/setup-uv@v5
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
run: uv python install 3.13
|
run: uv python install 3.14
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync
|
run: uv sync --group dev
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
run: uv build
|
run: uv build
|
||||||
|
|||||||
34
.github/workflows/ci.yml
vendored
34
.github/workflows/ci.yml
vendored
@@ -15,16 +15,16 @@ jobs:
|
|||||||
name: Lint (Ruff)
|
name: Lint (Ruff)
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Install uv
|
- name: Install uv
|
||||||
uses: astral-sh/setup-uv@v5
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
run: uv python install 3.13
|
run: uv python install 3.13
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync --extra dev
|
run: uv sync --group dev
|
||||||
|
|
||||||
- name: Run Ruff linter
|
- name: Run Ruff linter
|
||||||
run: uv run ruff check .
|
run: uv run ruff check .
|
||||||
@@ -36,16 +36,16 @@ jobs:
|
|||||||
name: Type Check (ty)
|
name: Type Check (ty)
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Install uv
|
- name: Install uv
|
||||||
uses: astral-sh/setup-uv@v5
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
run: uv python install 3.13
|
run: uv python install 3.13
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync --extra dev
|
run: uv sync --group dev
|
||||||
|
|
||||||
- name: Run ty
|
- name: Run ty
|
||||||
run: uv run ty check
|
run: uv run ty check
|
||||||
@@ -56,7 +56,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ["3.11", "3.12", "3.13"]
|
python-version: ["3.11", "3.12", "3.13", "3.14"]
|
||||||
|
|
||||||
services:
|
services:
|
||||||
postgres:
|
postgres:
|
||||||
@@ -74,27 +74,35 @@ jobs:
|
|||||||
--health-retries 5
|
--health-retries 5
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Install uv
|
- name: Install uv
|
||||||
uses: astral-sh/setup-uv@v5
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
run: uv python install ${{ matrix.python-version }}
|
run: uv python install ${{ matrix.python-version }}
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync --extra dev
|
run: uv sync --group dev
|
||||||
|
|
||||||
- name: Run tests with coverage
|
- name: Run tests with coverage
|
||||||
env:
|
env:
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/test_db
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/test_db
|
||||||
run: |
|
run: |
|
||||||
uv run pytest --cov --cov-report=xml --cov-report=term-missing
|
uv run pytest --cov --cov-report=xml --cov-report=term-missing --junitxml=junit.xml -o junit_family=legacy
|
||||||
|
|
||||||
- name: Upload coverage to Codecov
|
- name: Upload coverage to Codecov
|
||||||
if: matrix.python-version == '3.13'
|
if: matrix.python-version == '3.14'
|
||||||
uses: codecov/codecov-action@v5
|
uses: codecov/codecov-action@v6
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
|
report_type: coverage
|
||||||
files: ./coverage.xml
|
files: ./coverage.xml
|
||||||
fail_ci_if_error: false
|
fail_ci_if_error: false
|
||||||
|
|
||||||
|
- name: Upload test results to Codecov
|
||||||
|
if: matrix.python-version == '3.14'
|
||||||
|
uses: codecov/codecov-action@v6
|
||||||
|
with:
|
||||||
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
|
report_type: test_results
|
||||||
|
|||||||
62
.github/workflows/docs.yml
vendored
Normal file
62
.github/workflows/docs.yml
vendored
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
name: Documentation
|
||||||
|
|
||||||
|
on:
|
||||||
|
release:
|
||||||
|
types: [published]
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
deploy:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v6
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
run: uv python install 3.13
|
||||||
|
|
||||||
|
- run: uv sync --group dev
|
||||||
|
|
||||||
|
- name: Configure git
|
||||||
|
run: |
|
||||||
|
git config user.name "github-actions[bot]"
|
||||||
|
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||||
|
|
||||||
|
- name: Deploy documentation
|
||||||
|
run: |
|
||||||
|
VERSION=${GITHUB_REF_NAME#v}
|
||||||
|
MAJOR=$(echo "$VERSION" | cut -d. -f1)
|
||||||
|
DEPLOY_VERSION="v$(echo "$VERSION" | cut -d. -f1-2)"
|
||||||
|
|
||||||
|
# On new major: consolidate previous major's feature versions into vX
|
||||||
|
PREV_MAJOR=$((MAJOR - 1))
|
||||||
|
OLD_FEATURE_VERSIONS=$(uv run mike list 2>/dev/null | grep -oE "^v${PREV_MAJOR}\.[0-9]+" || true)
|
||||||
|
|
||||||
|
if [ -n "$OLD_FEATURE_VERSIONS" ]; then
|
||||||
|
LATEST_PREV_TAG=$(git tag -l "v${PREV_MAJOR}.*" | sort -V | tail -1)
|
||||||
|
|
||||||
|
if [ -n "$LATEST_PREV_TAG" ]; then
|
||||||
|
git checkout "$LATEST_PREV_TAG" -- docs/ docs_src/ src/ zensical.toml
|
||||||
|
if ! grep -q '\[project\.extra\.version\]' zensical.toml; then
|
||||||
|
printf '\n[project.extra.version]\nprovider = "mike"\ndefault = "stable"\nalias = true\n' >> zensical.toml
|
||||||
|
fi
|
||||||
|
uv run mike deploy "v${PREV_MAJOR}"
|
||||||
|
git checkout HEAD -- docs/ docs_src/ src/ zensical.toml
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Delete old feature versions
|
||||||
|
echo "$OLD_FEATURE_VERSIONS" | while read -r OLD_V; do
|
||||||
|
echo "Deleting $OLD_V"
|
||||||
|
uv run mike delete "$OLD_V"
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
uv run mike deploy --update-aliases "$DEPLOY_VERSION" stable
|
||||||
|
uv run mike set-default stable
|
||||||
|
git push origin gh-pages
|
||||||
34
.pre-commit-config.yaml
Normal file
34
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
# See https://pre-commit.com for more information
|
||||||
|
# See https://pre-commit.com/hooks.html for more hooks
|
||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v6.0.0
|
||||||
|
hooks:
|
||||||
|
- id: check-added-large-files
|
||||||
|
args: ["--maxkb=750"]
|
||||||
|
exclude: ^uv.lock$
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
- id: trailing-whitespace
|
||||||
|
|
||||||
|
- repo: local
|
||||||
|
hooks:
|
||||||
|
- id: local-ruff-check
|
||||||
|
name: ruff check
|
||||||
|
entry: uv run ruff check --force-exclude --fix --exit-non-zero-on-fix .
|
||||||
|
require_serial: true
|
||||||
|
language: unsupported
|
||||||
|
types: [python]
|
||||||
|
|
||||||
|
- id: local-ruff-format
|
||||||
|
name: ruff format
|
||||||
|
entry: uv run ruff format --force-exclude --exit-non-zero-on-format .
|
||||||
|
require_serial: true
|
||||||
|
language: unsupported
|
||||||
|
types: [python]
|
||||||
|
|
||||||
|
- id: local-ty
|
||||||
|
name: ty check
|
||||||
|
entry: uv run ty check
|
||||||
|
require_serial: true
|
||||||
|
language: unsupported
|
||||||
|
pass_filenames: false
|
||||||
@@ -1 +1 @@
|
|||||||
3.13
|
3.14
|
||||||
|
|||||||
38
README.md
38
README.md
@@ -1,6 +1,6 @@
|
|||||||
# FastAPI Toolsets
|
# FastAPI Toolsets
|
||||||
|
|
||||||
FastAPI Toolsets provides production-ready utilities for FastAPI applications built with async SQLAlchemy and PostgreSQL. It includes generic CRUD operations, a fixture system with dependency resolution, a Django-like CLI, standardized API responses, and structured exception handling with automatic OpenAPI documentation.
|
A modular collection of production-ready utilities for FastAPI. Install only what you need — from async CRUD and database helpers to CLI tooling, Prometheus metrics, and pytest fixtures. Each module is independently installable via optional extras, keeping your dependency footprint minimal.
|
||||||
|
|
||||||
[](https://github.com/d3vyce/fastapi-toolsets/actions/workflows/ci.yml)
|
[](https://github.com/d3vyce/fastapi-toolsets/actions/workflows/ci.yml)
|
||||||
[](https://codecov.io/gh/d3vyce/fastapi-toolsets)
|
[](https://codecov.io/gh/d3vyce/fastapi-toolsets)
|
||||||
@@ -20,17 +20,45 @@ FastAPI Toolsets provides production-ready utilities for FastAPI applications bu
|
|||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
|
The base package includes the core modules (CRUD, database, schemas, exceptions, fixtures, dependencies, model mixins, logging):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv add fastapi-toolsets
|
uv add fastapi-toolsets
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Install only the extras you need:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add "fastapi-toolsets[cli]"
|
||||||
|
uv add "fastapi-toolsets[metrics]"
|
||||||
|
uv add "fastapi-toolsets[pytest]"
|
||||||
|
```
|
||||||
|
|
||||||
|
Or install everything:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add "fastapi-toolsets[all]"
|
||||||
|
```
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- **CRUD**: Generic async CRUD operations with `CrudFactory`
|
### Core
|
||||||
- **Fixtures**: Fixture system with dependency management, context support and pytest integration
|
|
||||||
- **CLI**: Django-like command-line interface for fixtures and custom commands
|
- **CRUD**: Generic async CRUD operations with `CrudFactory`, built-in full-text/faceted search and Offset/Cursor pagination.
|
||||||
- **Standardized API Responses**: Consistent response format across your API
|
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
|
||||||
|
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
|
||||||
|
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
|
||||||
|
- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`)
|
||||||
|
- **Lifecycle Events**: Post-commit event system (`EventSession`, `listens_for`) that dispatches async/sync callbacks for insert, update, and delete operations
|
||||||
|
- **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`.
|
||||||
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
||||||
|
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`
|
||||||
|
|
||||||
|
### Optional
|
||||||
|
|
||||||
|
- **CLI**: Django-like command-line interface with fixture management and custom commands support
|
||||||
|
- **Metrics**: Prometheus metrics endpoint with provider/collector registry
|
||||||
|
- **Pytest Helpers**: Async test client, database session management, `pytest-xdist` support, and table cleanup utilities
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
|
|||||||
185
docs/examples/pagination-search.md
Normal file
185
docs/examples/pagination-search.md
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
# Pagination & search
|
||||||
|
|
||||||
|
This example builds an articles listing endpoint that supports **offset pagination**, **cursor pagination**, **full-text search**, **faceted filtering**, and **sorting** — all from a single `CrudFactory` definition.
|
||||||
|
|
||||||
|
## Models
|
||||||
|
|
||||||
|
```python title="models.py"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/models.py"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Schemas
|
||||||
|
|
||||||
|
```python title="schemas.py"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/schemas.py"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Crud
|
||||||
|
|
||||||
|
Declare `searchable_fields`, `facet_fields`, and `order_fields` once on [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory). All endpoints built from this class share the same defaults and can override them per call.
|
||||||
|
|
||||||
|
```python title="crud.py"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/crud.py"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Session dependency
|
||||||
|
|
||||||
|
```python title="db.py"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/db.py"
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! info "Deploy a Postgres DB with docker"
|
||||||
|
```bash
|
||||||
|
docker run -d --name postgres -e POSTGRES_USER=postgres -e POSTGRES_PASSWORD=postgres -e POSTGRES_DB=postgres -p 5432:5432 postgres:18-alpine
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## App
|
||||||
|
|
||||||
|
```python title="app.py"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/app.py"
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Routes
|
||||||
|
|
||||||
|
```python title="routes.py:1:16"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/routes.py:1:16"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Offset pagination
|
||||||
|
|
||||||
|
Best for admin panels or any UI that needs a total item count and numbered pages.
|
||||||
|
|
||||||
|
```python title="routes.py:19:37"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/routes.py:19:37"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example request**
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /articles/offset?page=2&items_per_page=10&search=fastapi&status=published&order_by=title&order=asc
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example response**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"pagination_type": "offset",
|
||||||
|
"data": [
|
||||||
|
{ "id": "3f47ac69-...", "title": "FastAPI tips", "status": "published", ... }
|
||||||
|
],
|
||||||
|
"pagination": {
|
||||||
|
"total_count": 42,
|
||||||
|
"pages": 5,
|
||||||
|
"page": 2,
|
||||||
|
"items_per_page": 10,
|
||||||
|
"has_more": true
|
||||||
|
},
|
||||||
|
"filter_attributes": {
|
||||||
|
"status": ["archived", "draft", "published"],
|
||||||
|
"name": ["backend", "frontend", "python"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
`filter_attributes` always reflects the values visible **after** applying the active filters. Use it to populate filter dropdowns on the client.
|
||||||
|
|
||||||
|
To skip the `COUNT(*)` query for better performance on large tables, pass `include_total=False`. `pagination.total_count` will be `null` in the response, while `has_more` remains accurate.
|
||||||
|
|
||||||
|
### Cursor pagination
|
||||||
|
|
||||||
|
Best for feeds, infinite scroll, or any high-throughput API where offset performance degrades.
|
||||||
|
|
||||||
|
```python title="routes.py:40:58"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/routes.py:40:58"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example request**
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /articles/cursor?items_per_page=10&status=published&order_by=created_at&order=desc
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example response**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"pagination_type": "cursor",
|
||||||
|
"data": [
|
||||||
|
{ "id": "3f47ac69-...", "title": "FastAPI tips", "status": "published", ... }
|
||||||
|
],
|
||||||
|
"pagination": {
|
||||||
|
"next_cursor": "eyJ2YWx1ZSI6ICIzZjQ3YWM2OS0uLi4ifQ==",
|
||||||
|
"prev_cursor": null,
|
||||||
|
"items_per_page": 10,
|
||||||
|
"has_more": true
|
||||||
|
},
|
||||||
|
"filter_attributes": {
|
||||||
|
"status": ["published"],
|
||||||
|
"name": ["backend", "python"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Pass `next_cursor` as the `cursor` query parameter on the next request to advance to the next page.
|
||||||
|
|
||||||
|
### Unified endpoint (both strategies)
|
||||||
|
|
||||||
|
!!! info "Added in `v2.3.0`"
|
||||||
|
|
||||||
|
[`paginate()`](../module/crud.md#unified-paginate--both-strategies-on-one-endpoint) lets a single endpoint support both strategies via a `pagination_type` query parameter. The `pagination_type` field in the response acts as a discriminator for frontend tooling.
|
||||||
|
|
||||||
|
```python title="routes.py:61:79"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/routes.py:61:79"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Offset request** (default)
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /articles/?pagination_type=offset&page=1&items_per_page=10
|
||||||
|
```
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"pagination_type": "offset",
|
||||||
|
"data": ["..."],
|
||||||
|
"pagination": { "total_count": 42, "pages": 5, "page": 1, "items_per_page": 10, "has_more": true }
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Cursor request**
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /articles/?pagination_type=cursor&items_per_page=10
|
||||||
|
GET /articles/?pagination_type=cursor&items_per_page=10&cursor=eyJ2YWx1ZSI6...
|
||||||
|
```
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"pagination_type": "cursor",
|
||||||
|
"data": ["..."],
|
||||||
|
"pagination": { "next_cursor": "eyJ2YWx1ZSI6...", "prev_cursor": null, "items_per_page": 10, "has_more": true }
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Search behaviour
|
||||||
|
|
||||||
|
Both endpoints inherit the same `searchable_fields` declared on `ArticleCrud`:
|
||||||
|
|
||||||
|
Search is **case-insensitive** and uses a `LIKE %query%` pattern. Pass a [`SearchConfig`](../reference/crud.md#fastapi_toolsets.crud.search.SearchConfig) instead of a plain string to control case sensitivity or switch to `match_mode="all"` (AND across all fields instead of OR).
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.crud import SearchConfig
|
||||||
|
|
||||||
|
# Both title AND body must contain "fastapi"
|
||||||
|
result = await ArticleCrud.offset_paginate(
|
||||||
|
session,
|
||||||
|
search=SearchConfig(query="fastapi", case_sensitive=True, match_mode="all"),
|
||||||
|
search_fields=[Article.title, Article.body],
|
||||||
|
)
|
||||||
|
```
|
||||||
69
docs/index.md
Normal file
69
docs/index.md
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
# FastAPI Toolsets
|
||||||
|
|
||||||
|
A modular collection of production-ready utilities for FastAPI. Install only what you need — from async CRUD and database helpers to CLI tooling, Prometheus metrics, and pytest fixtures. Each module is independently installable via optional extras, keeping your dependency footprint minimal.
|
||||||
|
|
||||||
|
[](https://github.com/d3vyce/fastapi-toolsets/actions/workflows/ci.yml)
|
||||||
|
[](https://codecov.io/gh/d3vyce/fastapi-toolsets)
|
||||||
|
[](https://github.com/astral-sh/ty)
|
||||||
|
[](https://github.com/astral-sh/uv)
|
||||||
|
[](https://github.com/astral-sh/ruff)
|
||||||
|
[](https://www.python.org/downloads/)
|
||||||
|
[](https://opensource.org/licenses/MIT)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Documentation**: [https://fastapi-toolsets.d3vyce.fr](https://fastapi-toolsets.d3vyce.fr)
|
||||||
|
|
||||||
|
**Source Code**: [https://github.com/d3vyce/fastapi-toolsets](https://github.com/d3vyce/fastapi-toolsets)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
The base package includes the core modules (CRUD, database, schemas, exceptions, fixtures, dependencies, model mixins, logging):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add fastapi-toolsets
|
||||||
|
```
|
||||||
|
|
||||||
|
Install only the extras you need:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add "fastapi-toolsets[cli]"
|
||||||
|
uv add "fastapi-toolsets[metrics]"
|
||||||
|
uv add "fastapi-toolsets[pytest]"
|
||||||
|
```
|
||||||
|
|
||||||
|
Or install everything:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add "fastapi-toolsets[all]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
### Core
|
||||||
|
|
||||||
|
- **CRUD**: Generic async CRUD operations with `CrudFactory`, built-in full-text/faceted search and Offset/Cursor pagination.
|
||||||
|
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
|
||||||
|
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
|
||||||
|
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
|
||||||
|
- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`).
|
||||||
|
- **Lifecycle Events**: Post-commit event system (`EventSession`, `listens_for`) that dispatches async/sync callbacks for insert, update, and delete operations.
|
||||||
|
- **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`.
|
||||||
|
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
||||||
|
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`
|
||||||
|
|
||||||
|
### Optional
|
||||||
|
|
||||||
|
- **CLI**: Django-like command-line interface with fixture management and custom commands support
|
||||||
|
- **Metrics**: Prometheus metrics endpoint with provider/collector registry
|
||||||
|
- **Pytest Helpers**: Async test client, database session management, `pytest-xdist` support, and table cleanup utilities
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT License - see [LICENSE](LICENSE) for details.
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
Contributions are welcome! Please feel free to submit issues and pull requests.
|
||||||
137
docs/migration/v2.md
Normal file
137
docs/migration/v2.md
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
# Migrating to v2.0
|
||||||
|
|
||||||
|
This page covers every breaking change introduced in **v2.0** and the steps required to update your code.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## CRUD
|
||||||
|
|
||||||
|
### `schema` is now required in `offset_paginate()` and `cursor_paginate()`
|
||||||
|
|
||||||
|
Calls that omit `schema` will now raise a `TypeError` at runtime.
|
||||||
|
|
||||||
|
Previously `schema` was optional; omitting it returned raw SQLAlchemy model instances inside the response. It is now a required keyword argument and the response always contains serialized schema instances.
|
||||||
|
|
||||||
|
=== "Before (`v1`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
# schema omitted — returned raw model instances
|
||||||
|
result = await UserCrud.offset_paginate(session=session, page=1)
|
||||||
|
result = await UserCrud.cursor_paginate(session=session, cursor=token)
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Now (`v2`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
result = await UserCrud.offset_paginate(session=session, page=1, schema=UserRead)
|
||||||
|
result = await UserCrud.cursor_paginate(session=session, cursor=token, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
### `as_response` removed from `create()`, `get()`, and `update()`
|
||||||
|
|
||||||
|
Passing `as_response` to these methods will raise a `TypeError` at runtime.
|
||||||
|
|
||||||
|
The `as_response=True` shorthand is replaced by passing a `schema` directly. The return value is a `Response[schema]` when `schema` is provided, or the raw model instance when it is not.
|
||||||
|
|
||||||
|
=== "Before (`v1`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
user = await UserCrud.create(session=session, obj=data, as_response=True)
|
||||||
|
user = await UserCrud.get(session=session, filters=filters, as_response=True)
|
||||||
|
user = await UserCrud.update(session=session, obj=data, filters, as_response=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Now (`v2`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
user = await UserCrud.create(session=session, obj=data, schema=UserRead)
|
||||||
|
user = await UserCrud.get(session=session, filters=filters, schema=UserRead)
|
||||||
|
user = await UserCrud.update(session=session, obj=data, filters, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
### `delete()`: `as_response` renamed and return type changed
|
||||||
|
|
||||||
|
`as_response` is gone, and the plain (non-response) call no longer returns `True`.
|
||||||
|
|
||||||
|
Two changes were made to `delete()`:
|
||||||
|
|
||||||
|
1. The `as_response` parameter is renamed to `return_response`.
|
||||||
|
2. When called without `return_response=True`, the method now returns `None` on success instead of `True`.
|
||||||
|
|
||||||
|
=== "Before (`v1`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
ok = await UserCrud.delete(session=session, filters=filters)
|
||||||
|
if ok: # True on success
|
||||||
|
...
|
||||||
|
|
||||||
|
response = await UserCrud.delete(session=session, filters=filters, as_response=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Now (`v2`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
await UserCrud.delete(session=session, filters=filters) # returns None
|
||||||
|
|
||||||
|
response = await UserCrud.delete(session=session, filters=filters, return_response=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
### `paginate()` alias removed
|
||||||
|
|
||||||
|
Any call to `crud.paginate(...)` will raise `AttributeError` at runtime.
|
||||||
|
|
||||||
|
The `paginate` shorthand was an alias for `offset_paginate`. It has been removed; call `offset_paginate` directly.
|
||||||
|
|
||||||
|
=== "Before (`v1`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
result = await UserCrud.paginate(session=session, page=2, items_per_page=20, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Now (`v2`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
result = await UserCrud.offset_paginate(session=session, page=2, items_per_page=20, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Exceptions
|
||||||
|
|
||||||
|
### Missing `api_error` raises `TypeError` at class definition time
|
||||||
|
|
||||||
|
Unfinished or stub exception subclasses that previously compiled fine will now fail on import.
|
||||||
|
|
||||||
|
In `v1`, a subclass without `api_error` would only fail when the exception was raised. In `v2`, `__init_subclass__` validates this at class definition time.
|
||||||
|
|
||||||
|
=== "Before (`v1`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
class MyError(ApiException):
|
||||||
|
pass # fine until raised
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Now (`v2`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
class MyError(ApiException):
|
||||||
|
pass # TypeError: MyError must define an 'api_error' class attribute.
|
||||||
|
```
|
||||||
|
|
||||||
|
For shared base classes that are not meant to be raised directly, use `abstract=True`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class BillingError(ApiException, abstract=True):
|
||||||
|
"""Base for all billing-related errors — not raised directly."""
|
||||||
|
|
||||||
|
class PaymentRequiredError(BillingError):
|
||||||
|
api_error = ApiError(code=402, msg="Payment Required", desc="...", err_code="BILLING-402")
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Schemas
|
||||||
|
|
||||||
|
### `Pagination` alias removed
|
||||||
|
|
||||||
|
`Pagination` was already deprecated in `v1` and is fully removed in `v2`, you now need to use [`OffsetPagination`](../reference/schemas.md#fastapi_toolsets.schemas.OffsetPagination) or [`CursorPagination`](../reference/schemas.md#fastapi_toolsets.schemas.CursorPagination).
|
||||||
180
docs/migration/v3.md
Normal file
180
docs/migration/v3.md
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
# Migrating to v3.0
|
||||||
|
|
||||||
|
This page covers every breaking change introduced in **v3.0** and the steps required to update your code.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## CRUD
|
||||||
|
|
||||||
|
### Facet keys now always use the full relationship chain
|
||||||
|
|
||||||
|
In `v2`, relationship facet fields used only the terminal column key (e.g. `"name"` for `Role.name`) and only prepended the relationship name when two facet fields shared the same column key. In `v3`, facet keys **always** include the full relationship chain joined by `__`, regardless of collisions.
|
||||||
|
|
||||||
|
=== "Before (`v2`)"
|
||||||
|
|
||||||
|
```
|
||||||
|
User.status -> status
|
||||||
|
(User.role, Role.name) -> name
|
||||||
|
(User.role, Role.permission, Permission.name) -> name
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Now (`v3`)"
|
||||||
|
|
||||||
|
```
|
||||||
|
User.status -> status
|
||||||
|
(User.role, Role.name) -> role__name
|
||||||
|
(User.role, Role.permission, Permission.name) -> role__permission__name
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### `*_params` dependencies consolidated into per-paginate methods
|
||||||
|
|
||||||
|
The six individual dependency methods (`offset_params`, `cursor_params`, `paginate_params`, `filter_params`, `search_params`, `order_params`) have been **removed** and replaced by three consolidated methods that bundle pagination, search, filter, and order into a single `Depends()` call.
|
||||||
|
|
||||||
|
| Removed | Replacement |
|
||||||
|
|---|---|
|
||||||
|
| `offset_params()` + `filter_params()` + `search_params()` + `order_params()` | `offset_paginate_params()` |
|
||||||
|
| `cursor_params()` + `filter_params()` + `search_params()` + `order_params()` | `cursor_paginate_params()` |
|
||||||
|
| `paginate_params()` + `filter_params()` + `search_params()` + `order_params()` | `paginate_params()` |
|
||||||
|
|
||||||
|
Each new method accepts `search`, `filter`, and `order` boolean toggles (all `True` by default) to disable features you don't need.
|
||||||
|
|
||||||
|
=== "Before (`v2`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.crud import OrderByClause
|
||||||
|
|
||||||
|
@router.get("/offset")
|
||||||
|
async def list_articles_offset(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[dict, Depends(ArticleCrud.offset_params(default_page_size=20))],
|
||||||
|
filter_by: Annotated[dict, Depends(ArticleCrud.filter_params())],
|
||||||
|
order_by: Annotated[OrderByClause | None, Depends(ArticleCrud.order_params(default_field=Article.created_at))],
|
||||||
|
search: str | None = None,
|
||||||
|
) -> OffsetPaginatedResponse[ArticleRead]:
|
||||||
|
return await ArticleCrud.offset_paginate(
|
||||||
|
session=session,
|
||||||
|
**params,
|
||||||
|
search=search,
|
||||||
|
filter_by=filter_by or None,
|
||||||
|
order_by=order_by,
|
||||||
|
schema=ArticleRead,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Now (`v3`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
@router.get("/offset")
|
||||||
|
async def list_articles_offset(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[
|
||||||
|
dict,
|
||||||
|
Depends(
|
||||||
|
ArticleCrud.offset_paginate_params(
|
||||||
|
default_page_size=20,
|
||||||
|
default_order_field=Article.created_at,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
) -> OffsetPaginatedResponse[ArticleRead]:
|
||||||
|
return await ArticleCrud.offset_paginate(session=session, **params, schema=ArticleRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
The same pattern applies to `cursor_paginate_params()` and `paginate_params()`. To disable a feature, pass the toggle:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# No search or ordering, only pagination + filtering
|
||||||
|
ArticleCrud.offset_paginate_params(search=False, order=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Models
|
||||||
|
|
||||||
|
The lifecycle event system has been rewritten. Callbacks are now registered with a module-level [`listens_for`](../reference/models.md#fastapi_toolsets.models.listens_for) decorator and dispatched by [`EventSession`](../reference/models.md#fastapi_toolsets.models.EventSession), replacing the mixin-based approach from `v2`.
|
||||||
|
|
||||||
|
### `WatchedFieldsMixin` and `@watch` removed
|
||||||
|
|
||||||
|
Importing `WatchedFieldsMixin` or `watch` will raise `ImportError`.
|
||||||
|
|
||||||
|
Model method callbacks (`on_create`, `on_delete`, `on_update`) and the `@watch` decorator are replaced by:
|
||||||
|
|
||||||
|
1. **`__watched_fields__`** — a plain class attribute to restrict which field changes trigger `UPDATE` events (replaces `@watch`).
|
||||||
|
2. **`@listens_for`** — a module-level decorator to register callbacks for one or more [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) types (replaces `on_create` / `on_delete` / `on_update` methods).
|
||||||
|
|
||||||
|
=== "Before (`v2`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import WatchedFieldsMixin, watch
|
||||||
|
|
||||||
|
@watch("status")
|
||||||
|
class Order(Base, UUIDMixin, WatchedFieldsMixin):
|
||||||
|
__tablename__ = "orders"
|
||||||
|
|
||||||
|
status: Mapped[str]
|
||||||
|
|
||||||
|
async def on_create(self):
|
||||||
|
await notify_new_order(self.id)
|
||||||
|
|
||||||
|
async def on_update(self, changes):
|
||||||
|
if "status" in changes:
|
||||||
|
await notify_status_change(self.id, changes["status"])
|
||||||
|
|
||||||
|
async def on_delete(self):
|
||||||
|
await notify_order_cancelled(self.id)
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Now (`v3`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import ModelEvent, UUIDMixin, listens_for
|
||||||
|
|
||||||
|
class Order(Base, UUIDMixin):
|
||||||
|
__tablename__ = "orders"
|
||||||
|
__watched_fields__ = ("status",)
|
||||||
|
|
||||||
|
status: Mapped[str]
|
||||||
|
|
||||||
|
@listens_for(Order, [ModelEvent.CREATE])
|
||||||
|
async def on_order_created(order: Order, event_type: ModelEvent, changes: None):
|
||||||
|
await notify_new_order(order.id)
|
||||||
|
|
||||||
|
@listens_for(Order, [ModelEvent.UPDATE])
|
||||||
|
async def on_order_updated(order: Order, event_type: ModelEvent, changes: dict):
|
||||||
|
if "status" in changes:
|
||||||
|
await notify_status_change(order.id, changes["status"])
|
||||||
|
|
||||||
|
@listens_for(Order, [ModelEvent.DELETE])
|
||||||
|
async def on_order_deleted(order: Order, event_type: ModelEvent, changes: None):
|
||||||
|
await notify_order_cancelled(order.id)
|
||||||
|
```
|
||||||
|
|
||||||
|
### `EventSession` now required
|
||||||
|
|
||||||
|
Without `EventSession`, lifecycle callbacks will silently stop firing.
|
||||||
|
|
||||||
|
Callbacks are now dispatched inside `EventSession.commit()` rather than via background tasks. Pass it as the session class when creating your session factory:
|
||||||
|
|
||||||
|
=== "Before (`v2`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
engine = create_async_engine("postgresql+asyncpg://...")
|
||||||
|
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Now (`v3`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
|
from fastapi_toolsets.models import EventSession
|
||||||
|
|
||||||
|
engine = create_async_engine("postgresql+asyncpg://...")
|
||||||
|
SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=EventSession)
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
If you use `create_db_session` from `fastapi_toolsets.pytest`, the session already uses `EventSession` — no changes needed in tests.
|
||||||
93
docs/module/cli.md
Normal file
93
docs/module/cli.md
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
# CLI
|
||||||
|
|
||||||
|
Typer-based command-line interface for managing your FastAPI application, with built-in fixture commands integration.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
=== "uv"
|
||||||
|
``` bash
|
||||||
|
uv add "fastapi-toolsets[cli]"
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "pip"
|
||||||
|
``` bash
|
||||||
|
pip install "fastapi-toolsets[cli]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `cli` module provides a `manager` entry point built with [Typer](https://typer.tiangolo.com/). It allow custom commands to be added in addition of the fixture commands when a [`FixtureRegistry`](../reference/fixtures.md#fastapi_toolsets.fixtures.registry.FixtureRegistry) and a database context are configured.
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Configure the CLI in your `pyproject.toml`:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[tool.fastapi-toolsets]
|
||||||
|
cli = "myapp.cli:cli" # Custom Typer app
|
||||||
|
fixtures = "myapp.fixtures:registry" # FixtureRegistry instance
|
||||||
|
db_context = "myapp.db:db_context" # Async context manager for sessions
|
||||||
|
```
|
||||||
|
|
||||||
|
All fields are optional. Without configuration, the `manager` command still works but no command are available.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Manager commands
|
||||||
|
manager --help
|
||||||
|
|
||||||
|
Usage: manager [OPTIONS] COMMAND [ARGS]...
|
||||||
|
|
||||||
|
FastAPI utilities CLI.
|
||||||
|
|
||||||
|
╭─ Options ────────────────────────────────────────────────────────────────────────╮
|
||||||
|
│ --install-completion Install completion for the current shell. │
|
||||||
|
│ --show-completion Show completion for the current shell, to copy it │
|
||||||
|
│ or customize the installation. │
|
||||||
|
│ --help Show this message and exit. │
|
||||||
|
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||||
|
╭─ Commands ───────────────────────────────────────────────────────────────────────╮
|
||||||
|
│ check-db │
|
||||||
|
│ fixtures Manage database fixtures. │
|
||||||
|
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||||
|
|
||||||
|
# Fixtures commands
|
||||||
|
manager fixtures --help
|
||||||
|
|
||||||
|
Usage: manager fixtures [OPTIONS] COMMAND [ARGS]...
|
||||||
|
|
||||||
|
Manage database fixtures.
|
||||||
|
|
||||||
|
╭─ Options ────────────────────────────────────────────────────────────────────────╮
|
||||||
|
│ --help Show this message and exit. │
|
||||||
|
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||||
|
╭─ Commands ───────────────────────────────────────────────────────────────────────╮
|
||||||
|
│ list List all registered fixtures. │
|
||||||
|
│ load Load fixtures into the database. │
|
||||||
|
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||||
|
```
|
||||||
|
|
||||||
|
## Custom CLI
|
||||||
|
|
||||||
|
You can extend the CLI by providing your own Typer app. The `manager` entry point will merge your app's commands with the built-in ones:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# myapp/cli.py
|
||||||
|
import typer
|
||||||
|
|
||||||
|
cli = typer.Typer()
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
|
def hello():
|
||||||
|
print("Hello from my app!")
|
||||||
|
```
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[tool.fastapi-toolsets]
|
||||||
|
cli = "myapp.cli:cli"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/cli.md)
|
||||||
577
docs/module/crud.md
Normal file
577
docs/module/crud.md
Normal file
@@ -0,0 +1,577 @@
|
|||||||
|
# CRUD
|
||||||
|
|
||||||
|
Generic async CRUD operations for SQLAlchemy models with search, pagination, and many-to-many support.
|
||||||
|
|
||||||
|
!!! info
|
||||||
|
This module has been coded and tested to be compatible with PostgreSQL only.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `crud` module provides [`AsyncCrud`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud), a base class with a full suite of async database operations, and [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory), a convenience function to instantiate it for a given model.
|
||||||
|
|
||||||
|
## Creating a CRUD class
|
||||||
|
|
||||||
|
### Factory style
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.crud import CrudFactory
|
||||||
|
from myapp.models import User
|
||||||
|
|
||||||
|
UserCrud = CrudFactory(model=User)
|
||||||
|
```
|
||||||
|
|
||||||
|
[`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory) dynamically creates a class named `AsyncUserCrud` with `User` as its model. This is the most concise option for straightforward CRUD with no custom logic.
|
||||||
|
|
||||||
|
### Subclass style
|
||||||
|
|
||||||
|
!!! info "Added in `v2.3.0`"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.crud.factory import AsyncCrud
|
||||||
|
from myapp.models import User
|
||||||
|
|
||||||
|
class UserCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
searchable_fields = [User.username, User.email]
|
||||||
|
default_load_options = [selectinload(User.role)]
|
||||||
|
```
|
||||||
|
|
||||||
|
Subclassing [`AsyncCrud`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud) directly is the preferred style when you need to add custom methods or when the configuration is complex enough to benefit from a named class body.
|
||||||
|
|
||||||
|
### Adding custom methods
|
||||||
|
|
||||||
|
```python
|
||||||
|
class UserCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_active(cls, session: AsyncSession) -> list[User]:
|
||||||
|
return await cls.get_multi(session, filters=[User.is_active == True])
|
||||||
|
```
|
||||||
|
|
||||||
|
### Sharing a custom base across multiple models
|
||||||
|
|
||||||
|
Define a generic base class with the shared methods, then subclass it for each model:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
from fastapi_toolsets.crud.factory import AsyncCrud
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=DeclarativeBase)
|
||||||
|
|
||||||
|
class AuditedCrud(AsyncCrud[T], Generic[T]):
|
||||||
|
"""Base CRUD with custom function"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_active(cls, session: AsyncSession):
|
||||||
|
return await cls.get_multi(session, filters=[cls.model.is_active == True])
|
||||||
|
|
||||||
|
|
||||||
|
class UserCrud(AuditedCrud[User]):
|
||||||
|
model = User
|
||||||
|
searchable_fields = [User.username, User.email]
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also use the factory shorthand with the same base by passing `base_class`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
UserCrud = CrudFactory(User, base_class=AuditedCrud)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Basic operations
|
||||||
|
|
||||||
|
!!! info "`get_or_none` added in `v2.2`"
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Create
|
||||||
|
user = await UserCrud.create(session=session, obj=UserCreateSchema(username="alice"))
|
||||||
|
|
||||||
|
# Get one (raises NotFoundError if not found)
|
||||||
|
user = await UserCrud.get(session=session, filters=[User.id == user_id])
|
||||||
|
|
||||||
|
# Get one or None (never raises)
|
||||||
|
user = await UserCrud.get_or_none(session=session, filters=[User.id == user_id])
|
||||||
|
|
||||||
|
# Get first or None
|
||||||
|
user = await UserCrud.first(session=session, filters=[User.email == email])
|
||||||
|
|
||||||
|
# Get multiple
|
||||||
|
users = await UserCrud.get_multi(session=session, filters=[User.is_active == True])
|
||||||
|
|
||||||
|
# Update
|
||||||
|
user = await UserCrud.update(session=session, obj=UserUpdateSchema(username="bob"), filters=[User.id == user_id])
|
||||||
|
|
||||||
|
# Delete
|
||||||
|
await UserCrud.delete(session=session, filters=[User.id == user_id])
|
||||||
|
|
||||||
|
# Count / exists
|
||||||
|
count = await UserCrud.count(session=session, filters=[User.is_active == True])
|
||||||
|
exists = await UserCrud.exists(session=session, filters=[User.email == email])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Fetching a single record
|
||||||
|
|
||||||
|
Three methods fetch a single record — choose based on how you want to handle the "not found" case and whether you need strict uniqueness:
|
||||||
|
|
||||||
|
| Method | Not found | Multiple results |
|
||||||
|
|---|---|---|
|
||||||
|
| `get` | raises `NotFoundError` | raises `MultipleResultsFound` |
|
||||||
|
| `get_or_none` | returns `None` | raises `MultipleResultsFound` |
|
||||||
|
| `first` | returns `None` | returns the first match silently |
|
||||||
|
|
||||||
|
Use `get` when the record must exist (e.g. a detail endpoint that should return 404):
|
||||||
|
|
||||||
|
```python
|
||||||
|
user = await UserCrud.get(session=session, filters=[User.id == user_id])
|
||||||
|
```
|
||||||
|
|
||||||
|
Use `get_or_none` when the record may not exist but you still want strict uniqueness enforcement:
|
||||||
|
|
||||||
|
```python
|
||||||
|
user = await UserCrud.get_or_none(session=session, filters=[User.email == email])
|
||||||
|
if user is None:
|
||||||
|
... # handle missing case without catching an exception
|
||||||
|
```
|
||||||
|
|
||||||
|
Use `first` when you only care about any one match and don't need uniqueness:
|
||||||
|
|
||||||
|
```python
|
||||||
|
user = await UserCrud.first(session=session, filters=[User.is_active == True])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pagination
|
||||||
|
|
||||||
|
!!! info "Added in `v1.1` (only offset_pagination via `paginate` if `<v1.1`)"
|
||||||
|
|
||||||
|
Three pagination methods are available. All return a typed response whose `pagination_type` field tells clients which strategy was used.
|
||||||
|
|
||||||
|
| | `offset_paginate` | `cursor_paginate` | `paginate` |
|
||||||
|
|---|---|---|---|
|
||||||
|
| Return type | `OffsetPaginatedResponse` | `CursorPaginatedResponse` | either, based on `pagination_type` param |
|
||||||
|
| Total count | Yes | No | / |
|
||||||
|
| Jump to arbitrary page | Yes | No | / |
|
||||||
|
| Performance on deep pages | Degrades | Constant | / |
|
||||||
|
| Stable under concurrent inserts | No | Yes | / |
|
||||||
|
| Use case | Admin panels, numbered pagination | Feeds, APIs, infinite scroll | single endpoint, both strategies |
|
||||||
|
|
||||||
|
### Offset pagination
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import Annotated
|
||||||
|
from fastapi import Depends
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def get_users(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[dict, Depends(UserCrud.offset_paginate_params())],
|
||||||
|
) -> OffsetPaginatedResponse[UserRead]:
|
||||||
|
return await UserCrud.offset_paginate(session=session, **params, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
The [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) method returns an [`OffsetPaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.OffsetPaginatedResponse):
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"pagination_type": "offset",
|
||||||
|
"data": ["..."],
|
||||||
|
"pagination": {
|
||||||
|
"total_count": 100,
|
||||||
|
"pages": 5,
|
||||||
|
"page": 1,
|
||||||
|
"items_per_page": 20,
|
||||||
|
"has_more": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Skipping the COUNT query
|
||||||
|
|
||||||
|
!!! info "Added in `v2.4.1`"
|
||||||
|
|
||||||
|
By default `offset_paginate` runs two queries: one for the page items and one `COUNT(*)` for `total_count`. On large tables the `COUNT` can be expensive. Pass `include_total=False` to `offset_paginate_params()` to skip it:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@router.get("")
|
||||||
|
async def get_users(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[dict, Depends(UserCrud.offset_paginate_params(include_total=False))],
|
||||||
|
) -> OffsetPaginatedResponse[UserRead]:
|
||||||
|
return await UserCrud.offset_paginate(session=session, **params, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Cursor pagination
|
||||||
|
|
||||||
|
```python
|
||||||
|
@router.get("")
|
||||||
|
async def list_users(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[dict, Depends(UserCrud.cursor_paginate_params())],
|
||||||
|
) -> CursorPaginatedResponse[UserRead]:
|
||||||
|
return await UserCrud.cursor_paginate(session=session, **params, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
The [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) method returns a [`CursorPaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.CursorPaginatedResponse):
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"pagination_type": "cursor",
|
||||||
|
"data": ["..."],
|
||||||
|
"pagination": {
|
||||||
|
"next_cursor": "eyJ2YWx1ZSI6ICIzZjQ3YWM2OS0uLi4ifQ==",
|
||||||
|
"prev_cursor": null,
|
||||||
|
"items_per_page": 20,
|
||||||
|
"has_more": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Pass `next_cursor` as the `cursor` query parameter on the next request to advance to the next page. `prev_cursor` is set on pages 2+ and points back to the first item of the current page. Both are `null` when there is no adjacent page.
|
||||||
|
|
||||||
|
#### Choosing a cursor column
|
||||||
|
|
||||||
|
The cursor column is set once on [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory) via the `cursor_column` parameter. It must be monotonically ordered for stable results:
|
||||||
|
|
||||||
|
- Auto-increment integer PKs
|
||||||
|
- UUID v7 PKs
|
||||||
|
- Timestamps
|
||||||
|
|
||||||
|
!!! warning
|
||||||
|
Random UUID v4 PKs are **not** suitable as cursor columns because their ordering is non-deterministic.
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
`cursor_column` is required. Calling [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) on a CRUD class that has no `cursor_column` configured raises a `ValueError`.
|
||||||
|
|
||||||
|
The cursor value is URL-safe base64-encoded (no padding) when returned to the client and decoded back to the correct Python type on the next request. The following SQLAlchemy column types are supported:
|
||||||
|
|
||||||
|
| SQLAlchemy type | Python type |
|
||||||
|
|---|---|
|
||||||
|
| `Integer`, `BigInteger`, `SmallInteger` | `int` |
|
||||||
|
| `Uuid` | `uuid.UUID` |
|
||||||
|
| `DateTime` | `datetime.datetime` |
|
||||||
|
| `Date` | `datetime.date` |
|
||||||
|
| `Float`, `Numeric` | `decimal.Decimal` |
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Paginate by the primary key
|
||||||
|
PostCrud = CrudFactory(model=Post, cursor_column=Post.id)
|
||||||
|
|
||||||
|
# Paginate by a timestamp column instead
|
||||||
|
PostCrud = CrudFactory(model=Post, cursor_column=Post.created_at)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Unified endpoint (both strategies)
|
||||||
|
|
||||||
|
!!! info "Added in `v2.3.0`"
|
||||||
|
|
||||||
|
[`paginate()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.paginate) dispatches to `offset_paginate` or `cursor_paginate` based on a `pagination_type` query parameter, letting you expose **one endpoint** that supports both strategies. The `pagination_type` field in the response tells clients which strategy was used, enabling frontend discriminated-union typing.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.schemas import PaginatedResponse
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def list_users(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[dict, Depends(UserCrud.paginate_params())],
|
||||||
|
) -> PaginatedResponse[UserRead]:
|
||||||
|
return await UserCrud.paginate(session, **params, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /users?pagination_type=offset&page=2&items_per_page=10
|
||||||
|
GET /users?pagination_type=cursor&cursor=eyJ2YWx1ZSI6...&items_per_page=10
|
||||||
|
```
|
||||||
|
|
||||||
|
## Search
|
||||||
|
|
||||||
|
Two search strategies are available, both compatible with [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) and [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate).
|
||||||
|
|
||||||
|
| | Full-text search | Faceted search |
|
||||||
|
|---|---|---|
|
||||||
|
| Input | Free-text string | Exact column values |
|
||||||
|
| Relationship support | Yes | Yes |
|
||||||
|
| Use case | Search bars | Filter dropdowns |
|
||||||
|
|
||||||
|
!!! info "You can use both search strategies in the same endpoint!"
|
||||||
|
|
||||||
|
### Full-text search
|
||||||
|
|
||||||
|
!!! info "Added in `v2.2.1`"
|
||||||
|
The model's primary key is always included in `searchable_fields` automatically, so searching by ID works out of the box without any configuration. When no `searchable_fields` are declared, only the primary key is searched.
|
||||||
|
|
||||||
|
Declare `searchable_fields` on the CRUD class. Relationship traversal is supported via tuples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
PostCrud = CrudFactory(
|
||||||
|
model=Post,
|
||||||
|
searchable_fields=[
|
||||||
|
Post.title,
|
||||||
|
Post.content,
|
||||||
|
(Post.author, User.username), # search across relationship
|
||||||
|
],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
You can override `searchable_fields` per call with `search_fields`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
result = await UserCrud.offset_paginate(
|
||||||
|
session=session,
|
||||||
|
search_fields=[User.country],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
This allows searching with both [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) and [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate):
|
||||||
|
|
||||||
|
```python
|
||||||
|
@router.get("")
|
||||||
|
async def get_users(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[dict, Depends(UserCrud.offset_paginate_params())],
|
||||||
|
) -> OffsetPaginatedResponse[UserRead]:
|
||||||
|
return await UserCrud.offset_paginate(session=session, **params, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
@router.get("")
|
||||||
|
async def get_users(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[dict, Depends(UserCrud.cursor_paginate_params())],
|
||||||
|
) -> CursorPaginatedResponse[UserRead]:
|
||||||
|
return await UserCrud.cursor_paginate(session=session, **params, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Faceted search
|
||||||
|
|
||||||
|
!!! info "Added in `v1.2`"
|
||||||
|
|
||||||
|
Declare `facet_fields` on the CRUD class to return distinct column values alongside paginated results. This is useful for populating filter dropdowns or building faceted search UIs.
|
||||||
|
|
||||||
|
Facet fields use the same syntax as `searchable_fields` — direct columns or relationship tuples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
UserCrud = CrudFactory(
|
||||||
|
model=User,
|
||||||
|
facet_fields=[
|
||||||
|
User.status,
|
||||||
|
User.country,
|
||||||
|
(User.role, Role.name), # value from a related model
|
||||||
|
],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
You can override `facet_fields` per call:
|
||||||
|
|
||||||
|
```python
|
||||||
|
result = await UserCrud.offset_paginate(
|
||||||
|
session=session,
|
||||||
|
facet_fields=[User.country],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
The distinct values are returned in the `filter_attributes` field of [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse):
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"data": ["..."],
|
||||||
|
"pagination": { "..." },
|
||||||
|
"filter_attributes": {
|
||||||
|
"status": ["active", "inactive"],
|
||||||
|
"country": ["DE", "FR", "US"],
|
||||||
|
"role__name": ["admin", "editor", "viewer"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Use `filter_by` to pass the client's chosen filter values directly — no need to build SQLAlchemy conditions by hand. Any unknown key raises [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError).
|
||||||
|
|
||||||
|
!!! info "The keys in `filter_by` are the same keys the client received in `filter_attributes`."
|
||||||
|
Keys use `__` as a separator for the full relationship chain. A direct column `User.status` produces `"status"`. A relationship tuple `(User.role, Role.name)` produces `"role__name"`. A deeper chain `(User.role, Role.permission, Permission.name)` produces `"role__permission__name"`.
|
||||||
|
|
||||||
|
`filter_by` and `filters` can be combined — both are applied with AND logic.
|
||||||
|
|
||||||
|
Facet filtering is built into the consolidated params dependencies. When `filter=True` (the default), facet fields are exposed as query parameters and collected into `filter_by` automatically:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
|
||||||
|
UserCrud = CrudFactory(
|
||||||
|
model=User,
|
||||||
|
facet_fields=[User.status, User.country, (User.role, Role.name)],
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("", response_model_exclude_none=True)
|
||||||
|
async def list_users(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[dict, Depends(UserCrud.offset_paginate_params())],
|
||||||
|
) -> OffsetPaginatedResponse[UserRead]:
|
||||||
|
return await UserCrud.offset_paginate(
|
||||||
|
session=session,
|
||||||
|
**params,
|
||||||
|
schema=UserRead,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Both single-value and multi-value query parameters work:
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /users?status=active → filter_by={"status": ["active"]}
|
||||||
|
GET /users?status=active&country=FR → filter_by={"status": ["active"], "country": ["FR"]}
|
||||||
|
GET /users?role__name=admin&role__name=editor → filter_by={"role__name": ["admin", "editor"]} (IN clause)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Sorting
|
||||||
|
|
||||||
|
!!! info "Added in `v1.3`"
|
||||||
|
|
||||||
|
Declare `order_fields` on the CRUD class to expose client-driven column ordering via `order_by` and `order` query parameters.
|
||||||
|
|
||||||
|
```python
|
||||||
|
UserCrud = CrudFactory(
|
||||||
|
model=User,
|
||||||
|
order_fields=[
|
||||||
|
User.name,
|
||||||
|
User.created_at,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Ordering is built into the consolidated params dependencies. When `order=True` (the default), `order_by` and `order` query parameters are exposed and resolved into an `OrderByClause` automatically:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def list_users(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[dict, Depends(UserCrud.offset_paginate_params(
|
||||||
|
default_order_field=User.created_at,
|
||||||
|
))],
|
||||||
|
) -> OffsetPaginatedResponse[UserRead]:
|
||||||
|
return await UserCrud.offset_paginate(session=session, **params, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
The dependency adds two query parameters to the endpoint:
|
||||||
|
|
||||||
|
| Parameter | Type |
|
||||||
|
| ---------- | --------------- |
|
||||||
|
| `order_by` | `str | null` |
|
||||||
|
| `order` | `asc` or `desc` |
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /users?order_by=name&order=asc → ORDER BY users.name ASC
|
||||||
|
GET /users?order_by=name&order=desc → ORDER BY users.name DESC
|
||||||
|
```
|
||||||
|
|
||||||
|
An unknown `order_by` value raises [`InvalidOrderFieldError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidOrderFieldError) (HTTP 422).
|
||||||
|
|
||||||
|
You can also pass `order_fields` directly to override the class-level defaults:
|
||||||
|
|
||||||
|
```python
|
||||||
|
params = UserCrud.offset_paginate_params(order_fields=[User.name])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Relationship loading
|
||||||
|
|
||||||
|
!!! info "Added in `v1.1`"
|
||||||
|
|
||||||
|
By default, SQLAlchemy relationships are not loaded unless explicitly requested. Instead of using `lazy="selectin"` on model definitions (which is implicit and applies globally), define a `default_load_options` on the CRUD class to control loading strategy explicitly.
|
||||||
|
|
||||||
|
!!! warning
|
||||||
|
Avoid using `lazy="selectin"` on model relationships. It fires silently on every query, cannot be disabled per-call, and can cause unexpected cascading loads through deep relationship chains. Use `default_load_options` instead.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
ArticleCrud = CrudFactory(
|
||||||
|
model=Article,
|
||||||
|
default_load_options=[
|
||||||
|
selectinload(Article.category),
|
||||||
|
selectinload(Article.tags),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
`default_load_options` applies automatically to all read operations (`get`, `first`, `get_multi`, `offset_paginate`, `cursor_paginate`). When `load_options` is passed at call-site, it **fully replaces** `default_load_options` for that query — giving you precise per-call control:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Only loads category, tags are not loaded
|
||||||
|
article = await ArticleCrud.get(
|
||||||
|
session=session,
|
||||||
|
filters=[Article.id == article_id],
|
||||||
|
load_options=[selectinload(Article.category)],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Loads nothing — useful for write-then-refresh flows or lightweight checks
|
||||||
|
articles = await ArticleCrud.get_multi(session=session, load_options=[])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Many-to-many relationships
|
||||||
|
|
||||||
|
Use `m2m_fields` to map schema fields containing lists of IDs to SQLAlchemy relationships. The CRUD class resolves and validates all IDs before persisting:
|
||||||
|
|
||||||
|
```python
|
||||||
|
PostCrud = CrudFactory(
|
||||||
|
model=Post,
|
||||||
|
m2m_fields={"tag_ids": Post.tags},
|
||||||
|
)
|
||||||
|
|
||||||
|
post = await PostCrud.create(session=session, obj=PostCreateSchema(title="Hello", tag_ids=[1, 2, 3]))
|
||||||
|
```
|
||||||
|
|
||||||
|
## Upsert
|
||||||
|
|
||||||
|
Atomic `INSERT ... ON CONFLICT DO UPDATE` using PostgreSQL:
|
||||||
|
|
||||||
|
```python
|
||||||
|
await UserCrud.upsert(
|
||||||
|
session=session,
|
||||||
|
obj=UserCreateSchema(email="alice@example.com", username="alice"),
|
||||||
|
index_elements=[User.email],
|
||||||
|
set_={"username"},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Response serialization
|
||||||
|
|
||||||
|
!!! info "Added in `v1.1`"
|
||||||
|
|
||||||
|
Pass a Pydantic schema class to `create`, `get`, `update`, or `offset_paginate` to serialize the result directly into that schema and wrap it in a [`Response[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.Response) or [`PaginatedResponse[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse):
|
||||||
|
|
||||||
|
```python
|
||||||
|
class UserRead(PydanticBase):
|
||||||
|
id: UUID
|
||||||
|
username: str
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/{uuid}",
|
||||||
|
responses=generate_error_responses(NotFoundError),
|
||||||
|
)
|
||||||
|
async def get_user(session: SessionDep, uuid: UUID) -> Response[UserRead]:
|
||||||
|
return await crud.UserCrud.get(
|
||||||
|
session=session,
|
||||||
|
filters=[User.id == uuid],
|
||||||
|
schema=UserRead,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def list_users(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[dict, Depends(crud.UserCrud.offset_paginate_params())],
|
||||||
|
) -> OffsetPaginatedResponse[UserRead]:
|
||||||
|
return await crud.UserCrud.offset_paginate(session=session, **params, schema=UserRead)
|
||||||
|
```
|
||||||
|
|
||||||
|
The schema must have `from_attributes=True` (or inherit from [`PydanticBase`](../reference/schemas.md#fastapi_toolsets.schemas.PydanticBase)) so it can be built from SQLAlchemy model instances.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/crud.md)
|
||||||
123
docs/module/db.md
Normal file
123
docs/module/db.md
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
# DB
|
||||||
|
|
||||||
|
SQLAlchemy async session management with transactions, table locking, and row-change polling.
|
||||||
|
|
||||||
|
!!! info
|
||||||
|
This module has been coded and tested to be compatible with PostgreSQL only.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `db` module provides helpers to create FastAPI dependencies and context managers for `AsyncSession`, along with utilities for nested transactions, table lock and polling for row changes.
|
||||||
|
|
||||||
|
## Session dependency
|
||||||
|
|
||||||
|
Use [`create_db_dependency`](../reference/db.md#fastapi_toolsets.db.create_db_dependency) to create a FastAPI dependency that yields a session and auto-commits on success:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||||
|
from fastapi_toolsets.db import create_db_dependency
|
||||||
|
|
||||||
|
engine = create_async_engine(url="postgresql+asyncpg://...", future=True)
|
||||||
|
session_maker = async_sessionmaker(bind=engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
get_db = create_db_dependency(session_maker=session_maker)
|
||||||
|
|
||||||
|
@router.get("/users")
|
||||||
|
async def list_users(session: AsyncSession = Depends(get_db)):
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Session context manager
|
||||||
|
|
||||||
|
Use [`create_db_context`](../reference/db.md#fastapi_toolsets.db.create_db_context) for sessions outside request handlers (e.g. background tasks, CLI commands):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import create_db_context
|
||||||
|
|
||||||
|
db_context = create_db_context(session_maker=session_maker)
|
||||||
|
|
||||||
|
async def seed():
|
||||||
|
async with db_context() as session:
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Nested transactions
|
||||||
|
|
||||||
|
[`get_transaction`](../reference/db.md#fastapi_toolsets.db.get_transaction) handles savepoints automatically, allowing safe nesting:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import get_transaction
|
||||||
|
|
||||||
|
async def create_user_with_role(session=session):
|
||||||
|
async with get_transaction(session=session):
|
||||||
|
...
|
||||||
|
async with get_transaction(session=session): # uses savepoint
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Table locking
|
||||||
|
|
||||||
|
[`lock_tables`](../reference/db.md#fastapi_toolsets.db.lock_tables) acquires PostgreSQL table-level locks before executing critical sections:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import lock_tables
|
||||||
|
|
||||||
|
async with lock_tables(session=session, tables=[User], mode="EXCLUSIVE"):
|
||||||
|
# No other transaction can modify User until this block exits
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
Available lock modes are defined in [`LockMode`](../reference/db.md#fastapi_toolsets.db.LockMode): `ACCESS_SHARE`, `ROW_SHARE`, `ROW_EXCLUSIVE`, `SHARE_UPDATE_EXCLUSIVE`, `SHARE`, `SHARE_ROW_EXCLUSIVE`, `EXCLUSIVE`, `ACCESS_EXCLUSIVE`.
|
||||||
|
|
||||||
|
## Row-change polling
|
||||||
|
|
||||||
|
[`wait_for_row_change`](../reference/db.md#fastapi_toolsets.db.wait_for_row_change) polls a row until a specific column changes value, useful for waiting on async side effects:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import wait_for_row_change
|
||||||
|
|
||||||
|
# Wait up to 30s for order.status to change
|
||||||
|
await wait_for_row_change(
|
||||||
|
session=session,
|
||||||
|
model=Order,
|
||||||
|
pk_value=order_id,
|
||||||
|
columns=[Order.status],
|
||||||
|
interval=1.0,
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Creating a database
|
||||||
|
|
||||||
|
!!! info "Added in `v2.1`"
|
||||||
|
|
||||||
|
[`create_database`](../reference/db.md#fastapi_toolsets.db.create_database) creates a database at a given URL. It connects to *server_url* and issues a `CREATE DATABASE` statement:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import create_database
|
||||||
|
|
||||||
|
SERVER_URL = "postgresql+asyncpg://postgres:postgres@localhost/postgres"
|
||||||
|
|
||||||
|
await create_database(db_name="myapp_test", server_url=SERVER_URL)
|
||||||
|
```
|
||||||
|
|
||||||
|
For test isolation with automatic cleanup, use [`create_worker_database`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_worker_database) from the `pytest` module instead — it handles drop-before, create, and drop-after automatically.
|
||||||
|
|
||||||
|
## Cleaning up tables
|
||||||
|
|
||||||
|
!!! info "Added in `v2.1`"
|
||||||
|
|
||||||
|
[`cleanup_tables`](../reference/db.md#fastapi_toolsets.db.cleanup_tables) truncates all tables:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import cleanup_tables
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
async def clean(db_session):
|
||||||
|
yield
|
||||||
|
await cleanup_tables(session=db_session, base=Base)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/db.md)
|
||||||
61
docs/module/dependencies.md
Normal file
61
docs/module/dependencies.md
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
# Dependencies
|
||||||
|
|
||||||
|
FastAPI dependency factories for automatic model resolution from path and body parameters.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `dependencies` module provides two factory functions that create FastAPI dependencies to fetch a model instance from the database automatically — either from a path parameter or from a request body field — and inject it directly into your route handler.
|
||||||
|
|
||||||
|
## `PathDependency`
|
||||||
|
|
||||||
|
[`PathDependency`](../reference/dependencies.md#fastapi_toolsets.dependencies.PathDependency) resolves a model from a URL path parameter and injects it into the route handler. Raises [`NotFoundError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NotFoundError) automatically if the record does not exist.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.dependencies import PathDependency
|
||||||
|
|
||||||
|
# Plain callable
|
||||||
|
UserDep = PathDependency(model=User, field=User.id, session_dep=get_db)
|
||||||
|
|
||||||
|
# Annotated
|
||||||
|
SessionDep = Annotated[AsyncSession, Depends(get_db)]
|
||||||
|
UserDep = PathDependency(model=User, field=User.id, session_dep=SessionDep)
|
||||||
|
|
||||||
|
@router.get("/users/{user_id}")
|
||||||
|
async def get_user(user: User = UserDep):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
By default the parameter name is inferred from the field (`user_id` for `User.id`). You can override it:
|
||||||
|
|
||||||
|
```python
|
||||||
|
UserDep = PathDependency(model=User, field=User.id, session_dep=get_db, param_name="id")
|
||||||
|
|
||||||
|
@router.get("/users/{id}")
|
||||||
|
async def get_user(user: User = UserDep):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
## `BodyDependency`
|
||||||
|
|
||||||
|
[`BodyDependency`](../reference/dependencies.md#fastapi_toolsets.dependencies.BodyDependency) resolves a model from a field in the request body. Useful when a body contains a foreign key and you want the full object injected:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.dependencies import BodyDependency
|
||||||
|
|
||||||
|
# Plain callable
|
||||||
|
RoleDep = BodyDependency(model=Role, field=Role.id, session_dep=get_db, body_field="role_id")
|
||||||
|
|
||||||
|
# Annotated
|
||||||
|
SessionDep = Annotated[AsyncSession, Depends(get_db)]
|
||||||
|
RoleDep = BodyDependency(model=Role, field=Role.id, session_dep=SessionDep, body_field="role_id")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/users")
|
||||||
|
async def create_user(body: UserCreateSchema, role: Role = RoleDep):
|
||||||
|
user = User(username=body.username, role=role)
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/dependencies.md)
|
||||||
131
docs/module/exceptions.md
Normal file
131
docs/module/exceptions.md
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
# Exceptions
|
||||||
|
|
||||||
|
Structured API exceptions with consistent error responses and automatic OpenAPI documentation.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `exceptions` module provides a set of pre-built HTTP exceptions and a FastAPI exception handler that formats all errors — including validation errors — into a uniform [`ErrorResponse`](../reference/schemas.md#fastapi_toolsets.schemas.ErrorResponse).
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
Register the exception handlers on your FastAPI app at startup:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app=app)
|
||||||
|
```
|
||||||
|
|
||||||
|
This registers handlers for:
|
||||||
|
|
||||||
|
- [`ApiException`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ApiException) — all custom exceptions below
|
||||||
|
- `HTTPException` — Starlette/FastAPI HTTP errors
|
||||||
|
- `RequestValidationError` — Pydantic request validation (422)
|
||||||
|
- `ResponseValidationError` — Pydantic response validation (422)
|
||||||
|
- `Exception` — unhandled errors (500)
|
||||||
|
|
||||||
|
It also patches `app.openapi()` to replace the default Pydantic 422 schema with a structured example matching the `ErrorResponse` format.
|
||||||
|
|
||||||
|
## Built-in exceptions
|
||||||
|
|
||||||
|
| Exception | Status | Default message |
|
||||||
|
|-----------|--------|-----------------|
|
||||||
|
| [`UnauthorizedError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.UnauthorizedError) | 401 | Unauthorized |
|
||||||
|
| [`ForbiddenError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ForbiddenError) | 403 | Forbidden |
|
||||||
|
| [`NotFoundError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NotFoundError) | 404 | Not Found |
|
||||||
|
| [`ConflictError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ConflictError) | 409 | Conflict |
|
||||||
|
| [`NoSearchableFieldsError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError) | 400 | No Searchable Fields |
|
||||||
|
| [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError) | 400 | Invalid Facet Filter |
|
||||||
|
| [`InvalidOrderFieldError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidOrderFieldError) | 422 | Invalid Order Field |
|
||||||
|
|
||||||
|
### Per-instance overrides
|
||||||
|
|
||||||
|
All built-in exceptions accept optional keyword arguments to customise the response for a specific raise site without changing the class defaults:
|
||||||
|
|
||||||
|
| Argument | Effect |
|
||||||
|
|----------|--------|
|
||||||
|
| `detail` | Overrides both `str(exc)` (log output) and the `message` field in the response body |
|
||||||
|
| `desc` | Overrides the `description` field |
|
||||||
|
| `data` | Overrides the `data` field |
|
||||||
|
|
||||||
|
```python
|
||||||
|
raise NotFoundError(detail="User 42 not found", desc="No user with that ID exists in the database.")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Custom exceptions
|
||||||
|
|
||||||
|
Subclass [`ApiException`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ApiException) and define an `api_error` class variable:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.exceptions import ApiException
|
||||||
|
from fastapi_toolsets.schemas import ApiError
|
||||||
|
|
||||||
|
class PaymentRequiredError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=402,
|
||||||
|
msg="Payment Required",
|
||||||
|
desc="Your subscription has expired.",
|
||||||
|
err_code="BILLING-402",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! warning
|
||||||
|
Subclasses that do not define `api_error` raise a `TypeError` at **class creation time**, not at raise time.
|
||||||
|
|
||||||
|
### Custom `__init__`
|
||||||
|
|
||||||
|
Override `__init__` to compute `detail`, `desc`, or `data` dynamically, then delegate to `super().__init__()`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class OrderValidationError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=422,
|
||||||
|
msg="Order Validation Failed",
|
||||||
|
desc="One or more order fields are invalid.",
|
||||||
|
err_code="ORDER-422",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, *field_errors: str) -> None:
|
||||||
|
super().__init__(
|
||||||
|
f"{len(field_errors)} validation error(s)",
|
||||||
|
desc=", ".join(field_errors),
|
||||||
|
data={"errors": [{"message": e} for e in field_errors]},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Intermediate base classes
|
||||||
|
|
||||||
|
Use `abstract=True` when creating a shared base that is not meant to be raised directly:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class BillingError(ApiException, abstract=True):
|
||||||
|
"""Base for all billing-related errors."""
|
||||||
|
|
||||||
|
class PaymentRequiredError(BillingError):
|
||||||
|
api_error = ApiError(code=402, msg="Payment Required", desc="...", err_code="BILLING-402")
|
||||||
|
|
||||||
|
class SubscriptionExpiredError(BillingError):
|
||||||
|
api_error = ApiError(code=402, msg="Subscription Expired", desc="...", err_code="BILLING-402-EXP")
|
||||||
|
```
|
||||||
|
|
||||||
|
## OpenAPI response documentation
|
||||||
|
|
||||||
|
Use [`generate_error_responses`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.generate_error_responses) to add error schemas to your endpoint's OpenAPI spec:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.exceptions import generate_error_responses, NotFoundError, ForbiddenError
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/users/{id}",
|
||||||
|
responses=generate_error_responses(NotFoundError, ForbiddenError),
|
||||||
|
)
|
||||||
|
async def get_user(...): ...
|
||||||
|
```
|
||||||
|
|
||||||
|
Multiple exceptions sharing the same HTTP status code are grouped under one entry, each appearing as a named example keyed by its `err_code`. This keeps the OpenAPI UI readable when several error variants map to the same status.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/exceptions.md)
|
||||||
198
docs/module/fixtures.md
Normal file
198
docs/module/fixtures.md
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
# Fixtures
|
||||||
|
|
||||||
|
Dependency-aware database seeding with context-based loading strategies.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `fixtures` module lets you define named fixtures with dependencies between them, then load them into the database in the correct order. Fixtures can be scoped to contexts (e.g. base data, testing data) so that only the relevant ones are loaded for each environment.
|
||||||
|
|
||||||
|
## Defining fixtures
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import FixtureRegistry, Context
|
||||||
|
|
||||||
|
fixtures = FixtureRegistry()
|
||||||
|
|
||||||
|
@fixtures.register
|
||||||
|
def roles():
|
||||||
|
return [
|
||||||
|
Role(id=1, name="admin"),
|
||||||
|
Role(id=2, name="user"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
|
||||||
|
def test_users():
|
||||||
|
return [
|
||||||
|
User(id=1, username="alice", role_id=1),
|
||||||
|
User(id=2, username="bob", role_id=2),
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
Dependencies declared via `depends_on` are resolved topologically — `roles` will always be loaded before `test_users`.
|
||||||
|
|
||||||
|
## Loading fixtures
|
||||||
|
|
||||||
|
By context with [`load_fixtures_by_context`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.load_fixtures_by_context):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import load_fixtures_by_context
|
||||||
|
|
||||||
|
async with db_context() as session:
|
||||||
|
await load_fixtures_by_context(session, fixtures, Context.TESTING)
|
||||||
|
```
|
||||||
|
|
||||||
|
Directly by name with [`load_fixtures`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.load_fixtures):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import load_fixtures
|
||||||
|
|
||||||
|
async with db_context() as session:
|
||||||
|
await load_fixtures(session, fixtures, "roles", "test_users")
|
||||||
|
```
|
||||||
|
|
||||||
|
Both functions return a `dict[str, list[...]]` mapping each fixture name to the list of loaded instances.
|
||||||
|
|
||||||
|
## Contexts
|
||||||
|
|
||||||
|
[`Context`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.Context) is an enum with predefined values:
|
||||||
|
|
||||||
|
| Context | Description |
|
||||||
|
|---------|-------------|
|
||||||
|
| `Context.BASE` | Core data required in all environments |
|
||||||
|
| `Context.TESTING` | Data only loaded during tests |
|
||||||
|
| `Context.DEVELOPMENT` | Data only loaded in development |
|
||||||
|
| `Context.PRODUCTION` | Data only loaded in production |
|
||||||
|
|
||||||
|
A fixture with no `contexts` defined takes `Context.BASE` by default.
|
||||||
|
|
||||||
|
### Custom contexts
|
||||||
|
|
||||||
|
Plain strings and any `Enum` subclass are accepted wherever a `Context` enum is expected.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
class AppContext(str, Enum):
|
||||||
|
STAGING = "staging"
|
||||||
|
DEMO = "demo"
|
||||||
|
|
||||||
|
@fixtures.register(contexts=[AppContext.STAGING])
|
||||||
|
def staging_data():
|
||||||
|
return [Config(key="feature_x", enabled=True)]
|
||||||
|
|
||||||
|
await load_fixtures_by_context(session, fixtures, AppContext.STAGING)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Default context for a registry
|
||||||
|
|
||||||
|
Pass `contexts` to `FixtureRegistry` to set a default for all fixtures registered in it:
|
||||||
|
|
||||||
|
```python
|
||||||
|
testing_registry = FixtureRegistry(contexts=[Context.TESTING])
|
||||||
|
|
||||||
|
@testing_registry.register # implicitly contexts=[Context.TESTING]
|
||||||
|
def test_orders():
|
||||||
|
return [Order(id=1, total=99)]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Same fixture name, multiple context variants
|
||||||
|
|
||||||
|
The same fixture name may be registered under different (non-overlapping) context sets. When multiple contexts are loaded together, all matching variants are merged:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@fixtures.register(contexts=[Context.BASE])
|
||||||
|
def users():
|
||||||
|
return [User(id=1, username="admin")]
|
||||||
|
|
||||||
|
@fixtures.register(contexts=[Context.TESTING])
|
||||||
|
def users():
|
||||||
|
return [User(id=2, username="tester")]
|
||||||
|
|
||||||
|
# loads both admin and tester
|
||||||
|
await load_fixtures_by_context(session, fixtures, Context.BASE, Context.TESTING)
|
||||||
|
```
|
||||||
|
|
||||||
|
Registering two variants with overlapping context sets raises `ValueError`.
|
||||||
|
|
||||||
|
## Load strategies
|
||||||
|
|
||||||
|
[`LoadStrategy`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.LoadStrategy) controls how the fixture loader handles rows that already exist:
|
||||||
|
|
||||||
|
| Strategy | Description |
|
||||||
|
|----------|-------------|
|
||||||
|
| `LoadStrategy.INSERT` | Insert only, fail on duplicates |
|
||||||
|
| `LoadStrategy.MERGE` | Insert or update on conflict (default) |
|
||||||
|
| `LoadStrategy.SKIP_EXISTING` | Skip rows that already exist |
|
||||||
|
|
||||||
|
```python
|
||||||
|
await load_fixtures_by_context(
|
||||||
|
session, fixtures, Context.BASE, strategy=LoadStrategy.SKIP_EXISTING
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Merging registries
|
||||||
|
|
||||||
|
Split fixture definitions across modules and merge them:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from myapp.fixtures.dev import dev_fixtures
|
||||||
|
from myapp.fixtures.prod import prod_fixtures
|
||||||
|
|
||||||
|
fixtures = FixtureRegistry()
|
||||||
|
fixtures.include_registry(registry=dev_fixtures)
|
||||||
|
fixtures.include_registry(registry=prod_fixtures)
|
||||||
|
```
|
||||||
|
|
||||||
|
Fixtures with the same name are allowed as long as their context sets do not overlap. Conflicting contexts raise `ValueError`.
|
||||||
|
|
||||||
|
## Looking up fixture instances
|
||||||
|
|
||||||
|
[`get_obj_by_attr`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.get_obj_by_attr) retrieves a specific instance from a fixture function by attribute value — useful when building cross-fixture `depends_on` relationships:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import get_obj_by_attr
|
||||||
|
|
||||||
|
@fixtures.register(depends_on=["roles"])
|
||||||
|
def users():
|
||||||
|
admin_role = get_obj_by_attr(roles, "name", "admin")
|
||||||
|
return [User(id=1, username="alice", role_id=admin_role.id)]
|
||||||
|
```
|
||||||
|
|
||||||
|
Raises `StopIteration` if no matching instance is found.
|
||||||
|
|
||||||
|
## Pytest integration
|
||||||
|
|
||||||
|
Use [`register_fixtures`](../reference/pytest.md#fastapi_toolsets.pytest.plugin.register_fixtures) to expose each fixture in your registry as an injectable pytest fixture named `fixture_{name}` by default:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# conftest.py
|
||||||
|
import pytest
|
||||||
|
from fastapi_toolsets.pytest import create_db_session, register_fixtures
|
||||||
|
from app.fixtures import registry
|
||||||
|
from app.models import Base
|
||||||
|
|
||||||
|
DATABASE_URL = "postgresql+asyncpg://user:pass@localhost/test_db"
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session():
|
||||||
|
async with create_db_session(database_url=DATABASE_URL, base=Base, cleanup=True) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
register_fixtures(registry=registry, namespace=globals())
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# test_users.py
|
||||||
|
async def test_user_can_login(fixture_users: list[User], fixture_roles: list[Role]):
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
The load order is resolved automatically from the `depends_on` declarations in your registry. Each generated fixture receives `db_session` as a dependency and returns the list of loaded model instances.
|
||||||
|
|
||||||
|
## CLI integration
|
||||||
|
|
||||||
|
Fixtures can be triggered from the CLI. See the [CLI module](cli.md) for setup instructions.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/fixtures.md)
|
||||||
39
docs/module/logger.md
Normal file
39
docs/module/logger.md
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# Logger
|
||||||
|
|
||||||
|
Lightweight logging utilities with consistent formatting and uvicorn integration.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `logger` module provides two helpers: one to configure the root logger (and uvicorn loggers) at startup, and one to retrieve a named logger anywhere in your codebase.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
Call [`configure_logging`](../reference/logger.md#fastapi_toolsets.logger.configure_logging) once at application startup:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import configure_logging
|
||||||
|
|
||||||
|
configure_logging(level="INFO")
|
||||||
|
```
|
||||||
|
|
||||||
|
This sets up a stdout handler with a consistent format and also configures uvicorn's access and error loggers so all log output shares the same style.
|
||||||
|
|
||||||
|
## Getting a logger
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__)
|
||||||
|
logger.info("User created")
|
||||||
|
```
|
||||||
|
|
||||||
|
When called without arguments, [`get_logger`](../reference/logger.md#fastapi_toolsets.logger.get_logger) auto-detects the caller's module name via frame inspection:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Equivalent to get_logger(name=__name__)
|
||||||
|
logger = get_logger()
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/logger.md)
|
||||||
109
docs/module/metrics.md
Normal file
109
docs/module/metrics.md
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
# Metrics
|
||||||
|
|
||||||
|
Prometheus metrics integration with a decorator-based registry and multi-process support.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
=== "uv"
|
||||||
|
``` bash
|
||||||
|
uv add "fastapi-toolsets[metrics]"
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "pip"
|
||||||
|
``` bash
|
||||||
|
pip install "fastapi-toolsets[metrics]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `metrics` module provides a [`MetricsRegistry`](../reference/metrics.md#fastapi_toolsets.metrics.registry.MetricsRegistry) to declare Prometheus metrics with decorators, and an [`init_metrics`](../reference/metrics.md#fastapi_toolsets.metrics.handler.init_metrics) function to mount a `/metrics` endpoint on your FastAPI app.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.metrics import MetricsRegistry, init_metrics
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
metrics = MetricsRegistry()
|
||||||
|
|
||||||
|
init_metrics(app=app, registry=metrics)
|
||||||
|
```
|
||||||
|
|
||||||
|
This mounts the `/metrics` endpoint that Prometheus can scrape.
|
||||||
|
|
||||||
|
## Declaring metrics
|
||||||
|
|
||||||
|
### Providers
|
||||||
|
|
||||||
|
Providers are called once at startup by `init_metrics`. The return value (the Prometheus metric object) is stored in the registry and can be retrieved later with [`registry.get(name)`](../reference/metrics.md#fastapi_toolsets.metrics.registry.MetricsRegistry.get).
|
||||||
|
|
||||||
|
Use providers when you want **deferred initialization**: the Prometheus metric is not registered with the global `CollectorRegistry` until `init_metrics` runs, not at import time. This is particularly useful for testing — importing the module in a test suite without calling `init_metrics` leaves no metrics registered, avoiding cross-test pollution.
|
||||||
|
|
||||||
|
It is also useful when metrics are defined across multiple modules and merged with `include_registry`: any code that needs a metric can call `metrics.get()` on the shared registry instead of importing the metric directly from its origin module.
|
||||||
|
|
||||||
|
If neither of these applies to you, declaring metrics at module level (e.g. `HTTP_REQUESTS = Counter(...)`) is simpler and equally valid.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from prometheus_client import Counter, Histogram
|
||||||
|
|
||||||
|
@metrics.register
|
||||||
|
def http_requests():
|
||||||
|
return Counter("http_requests_total", "Total HTTP requests", ["method", "status"])
|
||||||
|
|
||||||
|
@metrics.register
|
||||||
|
def request_duration():
|
||||||
|
return Histogram("request_duration_seconds", "Request duration")
|
||||||
|
```
|
||||||
|
|
||||||
|
To use a provider's metric elsewhere (e.g. in a middleware), call `metrics.get()` inside the handler — **not** at module level, as providers are only initialized when `init_metrics` runs:
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def metrics_middleware(request: Request, call_next):
|
||||||
|
response = await call_next(request)
|
||||||
|
metrics.get("http_requests").labels(
|
||||||
|
method=request.method, status=response.status_code
|
||||||
|
).inc()
|
||||||
|
return response
|
||||||
|
```
|
||||||
|
|
||||||
|
### Collectors
|
||||||
|
|
||||||
|
Collectors are called on every scrape. Use them for metrics that reflect current state (e.g. gauges).
|
||||||
|
|
||||||
|
!!! warning "Declare the metric at module level"
|
||||||
|
Do **not** instantiate the Prometheus metric inside the collector function. Doing so recreates it on every scrape, raising `ValueError: Duplicated timeseries in CollectorRegistry`. Declare it once at module level instead:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from prometheus_client import Gauge
|
||||||
|
|
||||||
|
_queue_depth = Gauge("queue_depth", "Current queue depth")
|
||||||
|
|
||||||
|
@metrics.register(collect=True)
|
||||||
|
def collect_queue_depth():
|
||||||
|
_queue_depth.set(get_current_queue_depth())
|
||||||
|
```
|
||||||
|
|
||||||
|
## Merging registries
|
||||||
|
|
||||||
|
Split metrics definitions across modules and merge them:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from myapp.metrics.http import http_metrics
|
||||||
|
from myapp.metrics.db import db_metrics
|
||||||
|
|
||||||
|
metrics = MetricsRegistry()
|
||||||
|
metrics.include_registry(registry=http_metrics)
|
||||||
|
metrics.include_registry(registry=db_metrics)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Multi-process mode
|
||||||
|
|
||||||
|
Multi-process support is enabled automatically when the `PROMETHEUS_MULTIPROC_DIR` environment variable is set. No code changes are required.
|
||||||
|
|
||||||
|
!!! warning "Environment variable name"
|
||||||
|
The correct variable is `PROMETHEUS_MULTIPROC_DIR` (not `PROMETHEUS_MULTIPROCESS_DIR`).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/metrics.md)
|
||||||
234
docs/module/models.md
Normal file
234
docs/module/models.md
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
# Models
|
||||||
|
|
||||||
|
!!! info "Added in `v2.0`"
|
||||||
|
|
||||||
|
Reusable SQLAlchemy 2.0 mixins for common column patterns, designed to be composed freely on any `DeclarativeBase` model.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `models` module provides mixins that each add a single, well-defined column behaviour. They work with standard SQLAlchemy 2.0 declarative syntax and are fully compatible with `AsyncSession`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import UUIDMixin, TimestampMixin
|
||||||
|
|
||||||
|
class Article(Base, UUIDMixin, TimestampMixin):
|
||||||
|
__tablename__ = "articles"
|
||||||
|
|
||||||
|
title: Mapped[str]
|
||||||
|
content: Mapped[str]
|
||||||
|
```
|
||||||
|
|
||||||
|
All timestamp columns are timezone-aware (`TIMESTAMPTZ`). All defaults are server-side (`clock_timestamp()`), so they are also applied when inserting rows via raw SQL outside the ORM.
|
||||||
|
|
||||||
|
## Mixins
|
||||||
|
|
||||||
|
### [`UUIDMixin`](../reference/models.md#fastapi_toolsets.models.UUIDMixin)
|
||||||
|
|
||||||
|
Adds a `id: UUID` primary key generated server-side by PostgreSQL using `gen_random_uuid()`. The value is retrieved via `RETURNING` after insert, so it is available on the Python object immediately after `flush()`.
|
||||||
|
|
||||||
|
!!! warning "Requires PostgreSQL 13+"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import UUIDMixin
|
||||||
|
|
||||||
|
class User(Base, UUIDMixin):
|
||||||
|
__tablename__ = "users"
|
||||||
|
|
||||||
|
username: Mapped[str]
|
||||||
|
|
||||||
|
# id is None before flush
|
||||||
|
user = User(username="alice")
|
||||||
|
session.add(user)
|
||||||
|
await session.flush()
|
||||||
|
print(user.id) # UUID('...')
|
||||||
|
```
|
||||||
|
|
||||||
|
### [`UUIDv7Mixin`](../reference/models.md#fastapi_toolsets.models.UUIDv7Mixin)
|
||||||
|
|
||||||
|
!!! info "Added in `v2.3`"
|
||||||
|
|
||||||
|
Adds a `id: UUID` primary key generated server-side by PostgreSQL using `uuidv7()`. It's a time-ordered UUID format that encodes a millisecond-precision timestamp in the most significant bits, making it naturally sortable and index-friendly.
|
||||||
|
|
||||||
|
!!! warning "Requires PostgreSQL 18+"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import UUIDv7Mixin
|
||||||
|
|
||||||
|
class Event(Base, UUIDv7Mixin):
|
||||||
|
__tablename__ = "events"
|
||||||
|
|
||||||
|
name: Mapped[str]
|
||||||
|
|
||||||
|
# id is None before flush
|
||||||
|
event = Event(name="user.signup")
|
||||||
|
session.add(event)
|
||||||
|
await session.flush()
|
||||||
|
print(event.id) # UUID('019...')
|
||||||
|
```
|
||||||
|
|
||||||
|
### [`CreatedAtMixin`](../reference/models.md#fastapi_toolsets.models.CreatedAtMixin)
|
||||||
|
|
||||||
|
Adds a `created_at: datetime` column set to `clock_timestamp()` on insert. The column has no `onupdate` hook — it is intentionally immutable after the row is created.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import UUIDMixin, CreatedAtMixin
|
||||||
|
|
||||||
|
class Order(Base, UUIDMixin, CreatedAtMixin):
|
||||||
|
__tablename__ = "orders"
|
||||||
|
|
||||||
|
total: Mapped[float]
|
||||||
|
```
|
||||||
|
|
||||||
|
### [`UpdatedAtMixin`](../reference/models.md#fastapi_toolsets.models.UpdatedAtMixin)
|
||||||
|
|
||||||
|
Adds an `updated_at: datetime` column set to `clock_timestamp()` on insert and automatically updated to `clock_timestamp()` on every ORM-level update (via SQLAlchemy's `onupdate` hook).
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import UUIDMixin, UpdatedAtMixin
|
||||||
|
|
||||||
|
class Post(Base, UUIDMixin, UpdatedAtMixin):
|
||||||
|
__tablename__ = "posts"
|
||||||
|
|
||||||
|
title: Mapped[str]
|
||||||
|
|
||||||
|
post = Post(title="Hello")
|
||||||
|
await session.flush()
|
||||||
|
await session.refresh(post)
|
||||||
|
|
||||||
|
post.title = "Hello World"
|
||||||
|
await session.flush()
|
||||||
|
await session.refresh(post)
|
||||||
|
print(post.updated_at)
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
`updated_at` is updated by SQLAlchemy at ORM flush time. If you update rows via raw SQL (e.g. `UPDATE posts SET ...`), the column will **not** be updated automatically — use a database trigger if you need that guarantee.
|
||||||
|
|
||||||
|
### [`TimestampMixin`](../reference/models.md#fastapi_toolsets.models.TimestampMixin)
|
||||||
|
|
||||||
|
Convenience mixin that combines [`CreatedAtMixin`](../reference/models.md#fastapi_toolsets.models.CreatedAtMixin) and [`UpdatedAtMixin`](../reference/models.md#fastapi_toolsets.models.UpdatedAtMixin). Equivalent to inheriting both.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import UUIDMixin, TimestampMixin
|
||||||
|
|
||||||
|
class Article(Base, UUIDMixin, TimestampMixin):
|
||||||
|
__tablename__ = "articles"
|
||||||
|
|
||||||
|
title: Mapped[str]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Lifecycle events
|
||||||
|
|
||||||
|
The event system provides lifecycle callbacks that fire **after commit**. If the transaction rolls back, no callback fires.
|
||||||
|
|
||||||
|
### Setup
|
||||||
|
|
||||||
|
Event dispatch requires [`EventSession`](../reference/models.md#fastapi_toolsets.models.EventSession). Pass it as the session class when creating your session factory:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
|
from fastapi_toolsets.models import EventSession
|
||||||
|
|
||||||
|
engine = create_async_engine("postgresql+asyncpg://...")
|
||||||
|
SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=EventSession)
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! info "Callbacks fire on `session.commit()` only — not on savepoints."
|
||||||
|
Savepoints created by [`get_transaction`](db.md) or `begin_nested()` do **not**
|
||||||
|
trigger callbacks. All events accumulated across flushes are dispatched once
|
||||||
|
when the outermost `commit()` is called.
|
||||||
|
|
||||||
|
### Events
|
||||||
|
|
||||||
|
Three event types are available, each corresponding to a [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) value:
|
||||||
|
|
||||||
|
| Event | Trigger |
|
||||||
|
|---|---|
|
||||||
|
| `ModelEvent.CREATE` | After `INSERT` commit |
|
||||||
|
| `ModelEvent.DELETE` | After `DELETE` commit |
|
||||||
|
| `ModelEvent.UPDATE` | After `UPDATE` commit on a watched field |
|
||||||
|
|
||||||
|
!!! warning "Callbacks fire only for ORM-level changes. Rows updated via raw SQL (`UPDATE ... SET ...`) are not detected."
|
||||||
|
|
||||||
|
### Watched fields
|
||||||
|
|
||||||
|
Set `__watched_fields__` on the model to restrict which field changes trigger `UPDATE` events. It must be a `tuple[str, ...]` — any other type raises `TypeError`:
|
||||||
|
|
||||||
|
| Class attribute | `UPDATE` behaviour |
|
||||||
|
|---|---|
|
||||||
|
| `__watched_fields__ = ("status", "role")` | Only fires when `status` or `role` changes |
|
||||||
|
| *(not set)* | Fires when **any** mapped field changes |
|
||||||
|
|
||||||
|
`__watched_fields__` is inherited through the class hierarchy via normal Python MRO. A subclass can override it:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class Order(Base, UUIDMixin):
|
||||||
|
__watched_fields__ = ("status",)
|
||||||
|
...
|
||||||
|
|
||||||
|
class UrgentOrder(Order):
|
||||||
|
# inherits __watched_fields__ = ("status",)
|
||||||
|
...
|
||||||
|
|
||||||
|
class PriorityOrder(Order):
|
||||||
|
__watched_fields__ = ("priority",)
|
||||||
|
# overrides parent — UPDATE fires only for priority changes
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
### Registering handlers
|
||||||
|
|
||||||
|
Register handlers with the [`listens_for`](../reference/models.md#fastapi_toolsets.models.listens_for) decorator. Every callback receives three arguments: the model instance, the [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) that triggered it, and a `changes` dict (`None` for `CREATE` and `DELETE`):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import ModelEvent, UUIDMixin, listens_for
|
||||||
|
|
||||||
|
class Order(Base, UUIDMixin):
|
||||||
|
__tablename__ = "orders"
|
||||||
|
__watched_fields__ = ("status",)
|
||||||
|
|
||||||
|
status: Mapped[str]
|
||||||
|
|
||||||
|
@listens_for(Order, [ModelEvent.CREATE])
|
||||||
|
async def on_order_created(order: Order, event_type: ModelEvent, changes: None):
|
||||||
|
await notify_new_order(order.id)
|
||||||
|
|
||||||
|
@listens_for(Order, [ModelEvent.DELETE])
|
||||||
|
async def on_order_deleted(order: Order, event_type: ModelEvent, changes: None):
|
||||||
|
await notify_order_cancelled(order.id)
|
||||||
|
|
||||||
|
@listens_for(Order, [ModelEvent.UPDATE])
|
||||||
|
async def on_order_updated(order: Order, event_type: ModelEvent, changes: dict):
|
||||||
|
if "status" in changes:
|
||||||
|
await notify_status_change(order.id, changes["status"])
|
||||||
|
```
|
||||||
|
|
||||||
|
Multiple handlers can be registered for the same model and event. Handlers registered on a parent class also fire for subclass instances.
|
||||||
|
|
||||||
|
A single handler can listen for multiple events at once. When `event_types` is omitted, the handler fires for all events:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@listens_for(Order, [ModelEvent.CREATE, ModelEvent.UPDATE])
|
||||||
|
async def on_order_changed(order: Order, event_type: ModelEvent, changes: dict | None):
|
||||||
|
await invalidate_cache(order.id)
|
||||||
|
|
||||||
|
@listens_for(Order) # all events
|
||||||
|
async def on_any_order_event(order: Order, event_type: ModelEvent, changes: dict | None):
|
||||||
|
await audit_log(order.id, event_type)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Field changes format
|
||||||
|
|
||||||
|
The `changes` dict maps each watched field that changed to `{"old": ..., "new": ...}`. Only fields that actually changed are included. For `CREATE` and `DELETE` events, `changes` is `None`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# CREATE / DELETE → changes is None
|
||||||
|
# status changed → {"status": {"old": "pending", "new": "shipped"}}
|
||||||
|
# two fields changed → {"status": {...}, "assigned_to": {...}}
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! info "Multiple flushes in one transaction are merged: the earliest `old` and latest `new` are preserved, and `on_update` fires only once per commit."
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/models.md)
|
||||||
95
docs/module/pytest.md
Normal file
95
docs/module/pytest.md
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
# Pytest
|
||||||
|
|
||||||
|
Testing helpers for FastAPI applications with async client, database sessions, and parallel worker support.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
=== "uv"
|
||||||
|
``` bash
|
||||||
|
uv add "fastapi-toolsets[pytest]"
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "pip"
|
||||||
|
``` bash
|
||||||
|
pip install "fastapi-toolsets[pytest]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `pytest` module provides utilities for setting up async test clients, managing test database sessions, and supporting parallel test execution with `pytest-xdist`.
|
||||||
|
|
||||||
|
## Creating an async client
|
||||||
|
|
||||||
|
Use [`create_async_client`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_async_client) to get an `httpx.AsyncClient` configured for your FastAPI app:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import create_async_client
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def http_client(db_session):
|
||||||
|
async def _override_get_db():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
async with create_async_client(
|
||||||
|
app=app,
|
||||||
|
base_url="http://127.0.0.1/api/v1",
|
||||||
|
dependency_overrides={get_db: _override_get_db},
|
||||||
|
) as c:
|
||||||
|
yield c
|
||||||
|
```
|
||||||
|
|
||||||
|
## Database sessions in tests
|
||||||
|
|
||||||
|
Use [`create_db_session`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_db_session) to create an isolated `AsyncSession` for a test, combined with [`create_worker_database`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_worker_database) to set up a per-worker database:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import create_worker_database, create_db_session
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
async def worker_db_url():
|
||||||
|
async with create_worker_database(
|
||||||
|
database_url=str(settings.SQLALCHEMY_DATABASE_URI)
|
||||||
|
) as url:
|
||||||
|
yield url
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session(worker_db_url):
|
||||||
|
async with create_db_session(
|
||||||
|
database_url=worker_db_url, base=Base, cleanup=True
|
||||||
|
) as session:
|
||||||
|
yield session
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! info
|
||||||
|
In this example, the database is reset between each test using the argument `cleanup=True`.
|
||||||
|
|
||||||
|
Use [`worker_database_url`](../reference/pytest.md#fastapi_toolsets.pytest.utils.worker_database_url) to derive the per-worker URL manually if needed:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import worker_database_url
|
||||||
|
|
||||||
|
url = worker_database_url("postgresql+asyncpg://user:pass@localhost/test_db", default_test_db="test")
|
||||||
|
# e.g. "postgresql+asyncpg://user:pass@localhost/test_db_gw0" under xdist
|
||||||
|
```
|
||||||
|
|
||||||
|
## Parallel testing with pytest-xdist
|
||||||
|
|
||||||
|
The examples above are already compatible with parallel test execution with `pytest-xdist`.
|
||||||
|
|
||||||
|
## Cleaning up tables
|
||||||
|
|
||||||
|
If you want to manually clean up a database you can use [`cleanup_tables`](../reference/db.md#fastapi_toolsets.db.cleanup_tables), this will truncate all tables between tests for fast isolation:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import cleanup_tables
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
async def clean(db_session):
|
||||||
|
yield
|
||||||
|
await cleanup_tables(session=db_session, base=Base)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/pytest.md)
|
||||||
140
docs/module/schemas.md
Normal file
140
docs/module/schemas.md
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
# Schemas
|
||||||
|
|
||||||
|
Standardized Pydantic response models for consistent API responses across your FastAPI application.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `schemas` module provides generic response wrappers that enforce a uniform response structure. All models use `from_attributes=True` for ORM compatibility and `validate_assignment=True` for runtime type safety.
|
||||||
|
|
||||||
|
## Response models
|
||||||
|
|
||||||
|
### [`Response[T]`](../reference/schemas.md#fastapi_toolsets.schemas.Response)
|
||||||
|
|
||||||
|
The most common wrapper for a single resource response.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.schemas import Response
|
||||||
|
|
||||||
|
@router.get("/users/{id}")
|
||||||
|
async def get_user(user: User = UserDep) -> Response[UserSchema]:
|
||||||
|
return Response(data=user, message="User retrieved")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Paginated response models
|
||||||
|
|
||||||
|
Three classes wrap paginated list results. Pick the one that matches your endpoint's strategy:
|
||||||
|
|
||||||
|
| Class | `pagination` type | `pagination_type` field | Use when |
|
||||||
|
|---|---|---|---|
|
||||||
|
| [`OffsetPaginatedResponse[T]`](#offsetpaginatedresponset) | `OffsetPagination` | `"offset"` (fixed) | endpoint always uses offset |
|
||||||
|
| [`CursorPaginatedResponse[T]`](#cursorpaginatedresponset) | `CursorPagination` | `"cursor"` (fixed) | endpoint always uses cursor |
|
||||||
|
| [`PaginatedResponse[T]`](#paginatedresponset) | `OffsetPagination \| CursorPagination` | — | unified endpoint supporting both strategies |
|
||||||
|
|
||||||
|
#### [`OffsetPaginatedResponse[T]`](../reference/schemas.md#fastapi_toolsets.schemas.OffsetPaginatedResponse)
|
||||||
|
|
||||||
|
!!! info "Added in `v2.3.0`"
|
||||||
|
|
||||||
|
Use as the return type when the endpoint always uses [`offset_paginate`](crud.md#offset-pagination). The `pagination` field is guaranteed to be an [`OffsetPagination`](../reference/schemas.md#fastapi_toolsets.schemas.OffsetPagination) object; the response always includes a `pagination_type: "offset"` discriminator.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.schemas import OffsetPaginatedResponse
|
||||||
|
|
||||||
|
@router.get("/users")
|
||||||
|
async def list_users(
|
||||||
|
page: int = 1,
|
||||||
|
items_per_page: int = 20,
|
||||||
|
) -> OffsetPaginatedResponse[UserSchema]:
|
||||||
|
return await UserCrud.offset_paginate(
|
||||||
|
session, page=page, items_per_page=items_per_page, schema=UserSchema
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response shape:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"pagination_type": "offset",
|
||||||
|
"data": ["..."],
|
||||||
|
"pagination": {
|
||||||
|
"total_count": 100,
|
||||||
|
"page": 1,
|
||||||
|
"items_per_page": 20,
|
||||||
|
"has_more": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### [`CursorPaginatedResponse[T]`](../reference/schemas.md#fastapi_toolsets.schemas.CursorPaginatedResponse)
|
||||||
|
|
||||||
|
!!! info "Added in `v2.3.0`"
|
||||||
|
|
||||||
|
Use as the return type when the endpoint always uses [`cursor_paginate`](crud.md#cursor-pagination). The `pagination` field is guaranteed to be a [`CursorPagination`](../reference/schemas.md#fastapi_toolsets.schemas.CursorPagination) object; the response always includes a `pagination_type: "cursor"` discriminator.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.schemas import CursorPaginatedResponse
|
||||||
|
|
||||||
|
@router.get("/events")
|
||||||
|
async def list_events(
|
||||||
|
cursor: str | None = None,
|
||||||
|
items_per_page: int = 20,
|
||||||
|
) -> CursorPaginatedResponse[EventSchema]:
|
||||||
|
return await EventCrud.cursor_paginate(
|
||||||
|
session, cursor=cursor, items_per_page=items_per_page, schema=EventSchema
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response shape:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"pagination_type": "cursor",
|
||||||
|
"data": ["..."],
|
||||||
|
"pagination": {
|
||||||
|
"next_cursor": "eyJpZCI6IDQyfQ==",
|
||||||
|
"prev_cursor": null,
|
||||||
|
"items_per_page": 20,
|
||||||
|
"has_more": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### [`PaginatedResponse[T]`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse)
|
||||||
|
|
||||||
|
Return type for endpoints that support **both** pagination strategies via a `pagination_type` query parameter (using [`paginate()`](crud.md#unified-paginate--both-strategies-on-one-endpoint)).
|
||||||
|
|
||||||
|
When used as a return annotation, `PaginatedResponse[T]` automatically expands to `Annotated[Union[CursorPaginatedResponse[T], OffsetPaginatedResponse[T]], Field(discriminator="pagination_type")]`, so FastAPI emits a proper `oneOf` + discriminator in the OpenAPI schema with no extra boilerplate:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.crud import PaginationType
|
||||||
|
from fastapi_toolsets.schemas import PaginatedResponse
|
||||||
|
|
||||||
|
@router.get("/users")
|
||||||
|
async def list_users(
|
||||||
|
pagination_type: PaginationType = PaginationType.OFFSET,
|
||||||
|
page: int = 1,
|
||||||
|
cursor: str | None = None,
|
||||||
|
items_per_page: int = 20,
|
||||||
|
) -> PaginatedResponse[UserSchema]:
|
||||||
|
return await UserCrud.paginate(
|
||||||
|
session,
|
||||||
|
pagination_type=pagination_type,
|
||||||
|
page=page,
|
||||||
|
cursor=cursor,
|
||||||
|
items_per_page=items_per_page,
|
||||||
|
schema=UserSchema,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Pagination metadata models
|
||||||
|
|
||||||
|
The optional `filter_attributes` field is populated when `facet_fields` are configured on the CRUD class (see [Filter attributes](crud.md#filter-attributes-facets)). It is `None` by default and can be hidden from API responses with `response_model_exclude_none=True`.
|
||||||
|
|
||||||
|
### [`ErrorResponse`](../reference/schemas.md#fastapi_toolsets.schemas.ErrorResponse)
|
||||||
|
|
||||||
|
Returned automatically by the exceptions handler.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/schemas.md)
|
||||||
7
docs/overrides/main.html
Normal file
7
docs/overrides/main.html
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
{% extends "base.html" %} {% block extrahead %}
|
||||||
|
<script
|
||||||
|
defer
|
||||||
|
src="https://analytics.d3vyce.fr/script.js"
|
||||||
|
data-website-id="338b8816-7b99-4c6a-82f3-15595be3fd47"
|
||||||
|
></script>
|
||||||
|
{{ super() }} {% endblock %}
|
||||||
27
docs/reference/cli.md
Normal file
27
docs/reference/cli.md
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
# `cli`
|
||||||
|
|
||||||
|
Here's the reference for the CLI configuration helpers used to load settings from `pyproject.toml`.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.cli.config`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.cli.config import (
|
||||||
|
import_from_string,
|
||||||
|
get_config_value,
|
||||||
|
get_fixtures_registry,
|
||||||
|
get_db_context,
|
||||||
|
get_custom_cli,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.import_from_string
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.get_config_value
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.get_fixtures_registry
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.get_db_context
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.get_custom_cli
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.utils.async_command
|
||||||
20
docs/reference/crud.md
Normal file
20
docs/reference/crud.md
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
# `crud`
|
||||||
|
|
||||||
|
Here's the reference for the CRUD classes, factory, and search utilities.
|
||||||
|
|
||||||
|
You can import the main symbols from `fastapi_toolsets.crud`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.crud import CrudFactory, AsyncCrud
|
||||||
|
from fastapi_toolsets.crud.search import SearchConfig, get_searchable_fields, build_search_filters
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.factory.AsyncCrud
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.factory.CrudFactory
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.search.SearchConfig
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.search.get_searchable_fields
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.search.build_search_filters
|
||||||
34
docs/reference/db.md
Normal file
34
docs/reference/db.md
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
# `db`
|
||||||
|
|
||||||
|
Here's the reference for all database session utilities, transaction helpers, and locking functions.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.db`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import (
|
||||||
|
LockMode,
|
||||||
|
cleanup_tables,
|
||||||
|
create_database,
|
||||||
|
create_db_dependency,
|
||||||
|
create_db_context,
|
||||||
|
get_transaction,
|
||||||
|
lock_tables,
|
||||||
|
wait_for_row_change,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.LockMode
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.create_db_dependency
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.create_db_context
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.get_transaction
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.lock_tables
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.wait_for_row_change
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.create_database
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.cleanup_tables
|
||||||
13
docs/reference/dependencies.md
Normal file
13
docs/reference/dependencies.md
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# `dependencies`
|
||||||
|
|
||||||
|
Here's the reference for the FastAPI dependency factory functions.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.dependencies`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.dependencies import PathDependency, BodyDependency
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.dependencies.PathDependency
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.dependencies.BodyDependency
|
||||||
40
docs/reference/exceptions.md
Normal file
40
docs/reference/exceptions.md
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
# `exceptions`
|
||||||
|
|
||||||
|
Here's the reference for all exception classes and handler utilities.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.exceptions`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.exceptions import (
|
||||||
|
ApiException,
|
||||||
|
UnauthorizedError,
|
||||||
|
ForbiddenError,
|
||||||
|
NotFoundError,
|
||||||
|
ConflictError,
|
||||||
|
NoSearchableFieldsError,
|
||||||
|
InvalidFacetFilterError,
|
||||||
|
InvalidOrderFieldError,
|
||||||
|
generate_error_responses,
|
||||||
|
init_exceptions_handlers,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.ApiException
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.UnauthorizedError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.ForbiddenError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.NotFoundError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.ConflictError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.InvalidOrderFieldError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.generate_error_responses
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.handler.init_exceptions_handlers
|
||||||
31
docs/reference/fixtures.md
Normal file
31
docs/reference/fixtures.md
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
# `fixtures`
|
||||||
|
|
||||||
|
Here's the reference for the fixture registry, enums, and loading utilities.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.fixtures`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import (
|
||||||
|
Context,
|
||||||
|
LoadStrategy,
|
||||||
|
Fixture,
|
||||||
|
FixtureRegistry,
|
||||||
|
load_fixtures,
|
||||||
|
load_fixtures_by_context,
|
||||||
|
get_obj_by_attr,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.enum.Context
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.enum.LoadStrategy
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.registry.Fixture
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.registry.FixtureRegistry
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.utils.load_fixtures
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.utils.load_fixtures_by_context
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.utils.get_obj_by_attr
|
||||||
13
docs/reference/logger.md
Normal file
13
docs/reference/logger.md
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# `logger`
|
||||||
|
|
||||||
|
Here's the reference for the logging utilities.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.logger`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import configure_logging, get_logger
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.logger.configure_logging
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.logger.get_logger
|
||||||
15
docs/reference/metrics.md
Normal file
15
docs/reference/metrics.md
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# `metrics`
|
||||||
|
|
||||||
|
Here's the reference for the Prometheus metrics registry and endpoint handler.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.metrics`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.metrics import Metric, MetricsRegistry, init_metrics
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.metrics.registry.Metric
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.metrics.registry.MetricsRegistry
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.metrics.handler.init_metrics
|
||||||
34
docs/reference/models.md
Normal file
34
docs/reference/models.md
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
# `models`
|
||||||
|
|
||||||
|
Here's the reference for the SQLAlchemy model mixins provided by the `models` module.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.models`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import (
|
||||||
|
EventSession,
|
||||||
|
ModelEvent,
|
||||||
|
UUIDMixin,
|
||||||
|
UUIDv7Mixin,
|
||||||
|
CreatedAtMixin,
|
||||||
|
UpdatedAtMixin,
|
||||||
|
TimestampMixin,
|
||||||
|
listens_for,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.models.EventSession
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.models.ModelEvent
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.models.UUIDMixin
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.models.UUIDv7Mixin
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.models.CreatedAtMixin
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.models.UpdatedAtMixin
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.models.TimestampMixin
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.models.listens_for
|
||||||
26
docs/reference/pytest.md
Normal file
26
docs/reference/pytest.md
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# `pytest`
|
||||||
|
|
||||||
|
Here's the reference for all testing utilities and pytest fixtures.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.pytest`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import (
|
||||||
|
register_fixtures,
|
||||||
|
create_async_client,
|
||||||
|
create_db_session,
|
||||||
|
worker_database_url,
|
||||||
|
create_worker_database,
|
||||||
|
cleanup_tables,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.plugin.register_fixtures
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.create_async_client
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.create_db_session
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.worker_database_url
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.create_worker_database
|
||||||
46
docs/reference/schemas.md
Normal file
46
docs/reference/schemas.md
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
# `schemas`
|
||||||
|
|
||||||
|
Here's the reference for all response models and types provided by the `schemas` module.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.schemas`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.schemas import (
|
||||||
|
PydanticBase,
|
||||||
|
ResponseStatus,
|
||||||
|
ApiError,
|
||||||
|
BaseResponse,
|
||||||
|
Response,
|
||||||
|
ErrorResponse,
|
||||||
|
OffsetPagination,
|
||||||
|
CursorPagination,
|
||||||
|
PaginationType,
|
||||||
|
PaginatedResponse,
|
||||||
|
OffsetPaginatedResponse,
|
||||||
|
CursorPaginatedResponse,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.PydanticBase
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.ResponseStatus
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.ApiError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.BaseResponse
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.Response
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.ErrorResponse
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.OffsetPagination
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.CursorPagination
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.PaginationType
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.PaginatedResponse
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.OffsetPaginatedResponse
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.CursorPaginatedResponse
|
||||||
0
docs_src/__init__.py
Normal file
0
docs_src/__init__.py
Normal file
0
docs_src/examples/__init__.py
Normal file
0
docs_src/examples/__init__.py
Normal file
0
docs_src/examples/pagination_search/__init__.py
Normal file
0
docs_src/examples/pagination_search/__init__.py
Normal file
9
docs_src/examples/pagination_search/app.py
Normal file
9
docs_src/examples/pagination_search/app.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||||
|
|
||||||
|
from .routes import router
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app=app)
|
||||||
|
app.include_router(router=router)
|
||||||
21
docs_src/examples/pagination_search/crud.py
Normal file
21
docs_src/examples/pagination_search/crud.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
from fastapi_toolsets.crud import CrudFactory
|
||||||
|
|
||||||
|
from .models import Article, Category
|
||||||
|
|
||||||
|
ArticleCrud = CrudFactory(
|
||||||
|
model=Article,
|
||||||
|
cursor_column=Article.created_at,
|
||||||
|
searchable_fields=[ # default fields for full-text search
|
||||||
|
Article.title,
|
||||||
|
Article.body,
|
||||||
|
(Article.category, Category.name),
|
||||||
|
],
|
||||||
|
facet_fields=[ # fields exposed as filter dropdowns
|
||||||
|
Article.status,
|
||||||
|
(Article.category, Category.name),
|
||||||
|
],
|
||||||
|
order_fields=[ # fields exposed for client-driven ordering
|
||||||
|
Article.title,
|
||||||
|
Article.created_at,
|
||||||
|
],
|
||||||
|
)
|
||||||
17
docs_src/examples/pagination_search/db.py
Normal file
17
docs_src/examples/pagination_search/db.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from fastapi_toolsets.db import create_db_context, create_db_dependency
|
||||||
|
|
||||||
|
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/postgres"
|
||||||
|
|
||||||
|
engine = create_async_engine(url=DATABASE_URL, future=True)
|
||||||
|
async_session_maker = async_sessionmaker(bind=engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
get_db = create_db_dependency(session_maker=async_session_maker)
|
||||||
|
get_db_context = create_db_context(session_maker=async_session_maker)
|
||||||
|
|
||||||
|
|
||||||
|
SessionDep = Annotated[AsyncSession, Depends(get_db)]
|
||||||
34
docs_src/examples/pagination_search/models.py
Normal file
34
docs_src/examples/pagination_search/models.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
import uuid
|
||||||
|
|
||||||
|
from sqlalchemy import Boolean, ForeignKey, String, Text
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from fastapi_toolsets.models import CreatedAtMixin
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Category(Base):
|
||||||
|
__tablename__ = "categories"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4)
|
||||||
|
name: Mapped[str] = mapped_column(String(64), unique=True)
|
||||||
|
|
||||||
|
articles: Mapped[list["Article"]] = relationship(back_populates="category")
|
||||||
|
|
||||||
|
|
||||||
|
class Article(Base, CreatedAtMixin):
|
||||||
|
__tablename__ = "articles"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4)
|
||||||
|
title: Mapped[str] = mapped_column(String(256))
|
||||||
|
body: Mapped[str] = mapped_column(Text)
|
||||||
|
status: Mapped[str] = mapped_column(String(32))
|
||||||
|
published: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||||
|
category_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
|
ForeignKey("categories.id"), nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
|
category: Mapped["Category | None"] = relationship(back_populates="articles")
|
||||||
79
docs_src/examples/pagination_search/routes.py
Normal file
79
docs_src/examples/pagination_search/routes.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
|
from fastapi_toolsets.schemas import (
|
||||||
|
CursorPaginatedResponse,
|
||||||
|
OffsetPaginatedResponse,
|
||||||
|
PaginatedResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .crud import ArticleCrud
|
||||||
|
from .db import SessionDep
|
||||||
|
from .models import Article
|
||||||
|
from .schemas import ArticleRead
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/articles")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/offset")
|
||||||
|
async def list_articles_offset(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[
|
||||||
|
dict,
|
||||||
|
Depends(
|
||||||
|
ArticleCrud.offset_paginate_params(
|
||||||
|
default_page_size=20,
|
||||||
|
max_page_size=100,
|
||||||
|
default_order_field=Article.created_at,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
) -> OffsetPaginatedResponse[ArticleRead]:
|
||||||
|
return await ArticleCrud.offset_paginate(
|
||||||
|
session=session,
|
||||||
|
**params,
|
||||||
|
schema=ArticleRead,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/cursor")
|
||||||
|
async def list_articles_cursor(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[
|
||||||
|
dict,
|
||||||
|
Depends(
|
||||||
|
ArticleCrud.cursor_paginate_params(
|
||||||
|
default_page_size=20,
|
||||||
|
max_page_size=100,
|
||||||
|
default_order_field=Article.created_at,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
) -> CursorPaginatedResponse[ArticleRead]:
|
||||||
|
return await ArticleCrud.cursor_paginate(
|
||||||
|
session=session,
|
||||||
|
**params,
|
||||||
|
schema=ArticleRead,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/")
|
||||||
|
async def list_articles(
|
||||||
|
session: SessionDep,
|
||||||
|
params: Annotated[
|
||||||
|
dict,
|
||||||
|
Depends(
|
||||||
|
ArticleCrud.paginate_params(
|
||||||
|
default_page_size=20,
|
||||||
|
max_page_size=100,
|
||||||
|
default_order_field=Article.created_at,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
) -> PaginatedResponse[ArticleRead]:
|
||||||
|
return await ArticleCrud.paginate(
|
||||||
|
session,
|
||||||
|
**params,
|
||||||
|
schema=ArticleRead,
|
||||||
|
)
|
||||||
13
docs_src/examples/pagination_search/schemas.py
Normal file
13
docs_src/examples/pagination_search/schemas.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import datetime
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from fastapi_toolsets.schemas import PydanticBase
|
||||||
|
|
||||||
|
|
||||||
|
class ArticleRead(PydanticBase):
|
||||||
|
id: uuid.UUID
|
||||||
|
created_at: datetime.datetime
|
||||||
|
title: str
|
||||||
|
status: str
|
||||||
|
published: bool
|
||||||
|
category_id: uuid.UUID | None
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "fastapi-toolsets"
|
name = "fastapi-toolsets"
|
||||||
version = "0.2.0"
|
version = "3.0.0"
|
||||||
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
|
description = "Production-ready utilities for FastAPI applications"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
license-files = ["LICENSE"]
|
license-files = ["LICENSE"]
|
||||||
@@ -11,7 +11,7 @@ authors = [
|
|||||||
]
|
]
|
||||||
keywords = ["fastapi", "sqlalchemy", "postgresql"]
|
keywords = ["fastapi", "sqlalchemy", "postgresql"]
|
||||||
classifiers = [
|
classifiers = [
|
||||||
"Development Status :: 4 - Beta",
|
"Development Status :: 5 - Production/Stable",
|
||||||
"Framework :: AsyncIO",
|
"Framework :: AsyncIO",
|
||||||
"Framework :: FastAPI",
|
"Framework :: FastAPI",
|
||||||
"Framework :: Pydantic",
|
"Framework :: Pydantic",
|
||||||
@@ -24,18 +24,17 @@ classifiers = [
|
|||||||
"Programming Language :: Python :: 3.11",
|
"Programming Language :: Python :: 3.11",
|
||||||
"Programming Language :: Python :: 3.12",
|
"Programming Language :: Python :: 3.12",
|
||||||
"Programming Language :: Python :: 3.13",
|
"Programming Language :: Python :: 3.13",
|
||||||
|
"Programming Language :: Python :: 3.14",
|
||||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||||
"Topic :: Software Development :: Libraries",
|
"Topic :: Software Development :: Libraries",
|
||||||
"Topic :: Software Development",
|
"Topic :: Software Development",
|
||||||
"Typing :: Typed",
|
"Typing :: Typed",
|
||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"fastapi>=0.100.0",
|
|
||||||
"sqlalchemy[asyncio]>=2.0",
|
|
||||||
"asyncpg>=0.29.0",
|
"asyncpg>=0.29.0",
|
||||||
|
"fastapi>=0.100.0",
|
||||||
"pydantic>=2.0",
|
"pydantic>=2.0",
|
||||||
"typer>=0.9.0",
|
"sqlalchemy[asyncio]>=2.0",
|
||||||
"httpx>=0.25.0",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
@@ -45,23 +44,49 @@ Repository = "https://github.com/d3vyce/fastapi-toolsets"
|
|||||||
Issues = "https://github.com/d3vyce/fastapi-toolsets/issues"
|
Issues = "https://github.com/d3vyce/fastapi-toolsets/issues"
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
test = [
|
cli = [
|
||||||
"pytest>=8.0.0",
|
"typer>=0.9.0",
|
||||||
"pytest-anyio>=0.0.0",
|
|
||||||
"coverage>=7.0.0",
|
|
||||||
"pytest-cov>=4.0.0",
|
|
||||||
]
|
]
|
||||||
dev = [
|
metrics = [
|
||||||
"fastapi-toolsets[test]",
|
"prometheus_client>=0.20.0",
|
||||||
"ruff>=0.1.0",
|
]
|
||||||
"ty>=0.0.1a0",
|
pytest = [
|
||||||
|
"httpx>=0.25.0",
|
||||||
|
"pytest-xdist>=3.0.0",
|
||||||
|
"pytest>=8.0.0",
|
||||||
|
]
|
||||||
|
all = [
|
||||||
|
"fastapi-toolsets[cli,metrics,pytest]",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
fastapi-toolsets = "fastapi_toolsets.cli:app"
|
manager = "fastapi_toolsets.cli.app:cli"
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = [
|
||||||
|
{include-group = "tests"},
|
||||||
|
{include-group = "docs"},
|
||||||
|
"fastapi-toolsets[all]",
|
||||||
|
"prek>=0.3.8",
|
||||||
|
"ruff>=0.1.0",
|
||||||
|
"ty>=0.0.1a0",
|
||||||
|
]
|
||||||
|
tests = [
|
||||||
|
"coverage>=7.0.0",
|
||||||
|
"httpx>=0.25.0",
|
||||||
|
"pytest-anyio>=0.0.0",
|
||||||
|
"pytest-cov>=4.0.0",
|
||||||
|
"pytest-xdist>=3.0.0",
|
||||||
|
"pytest>=8.0.0",
|
||||||
|
]
|
||||||
|
docs = [
|
||||||
|
"mike",
|
||||||
|
"mkdocstrings-python>=2.0.2",
|
||||||
|
"zensical>=0.0.30",
|
||||||
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["uv_build>=0.9.26,<0.10.0"]
|
requires = ["uv_build>=0.10,<0.12.0"]
|
||||||
build-backend = "uv_build"
|
build-backend = "uv_build"
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
@@ -80,3 +105,6 @@ exclude_lines = [
|
|||||||
"if TYPE_CHECKING:",
|
"if TYPE_CHECKING:",
|
||||||
"raise NotImplementedError",
|
"raise NotImplementedError",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.uv.sources]
|
||||||
|
mike = { git = "https://github.com/squidfunk/mike.git", tag = "2.2.0+zensical-0.1.0" }
|
||||||
|
|||||||
@@ -21,4 +21,4 @@ Example usage:
|
|||||||
return Response(data={"user": user.username}, message="Success")
|
return Response(data={"user": user.username}, message="Success")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.2.0"
|
__version__ = "3.0.0"
|
||||||
|
|||||||
9
src/fastapi_toolsets/_imports.py
Normal file
9
src/fastapi_toolsets/_imports.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
"""Optional dependency helpers."""
|
||||||
|
|
||||||
|
|
||||||
|
def require_extra(package: str, extra: str) -> None:
|
||||||
|
"""Raise *ImportError* with an actionable install instruction."""
|
||||||
|
raise ImportError(
|
||||||
|
f"'{package}' is required to use this module. "
|
||||||
|
f"Install it with: pip install fastapi-toolsets[{extra}]"
|
||||||
|
)
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
"""CLI for FastAPI projects."""
|
"""CLI for FastAPI projects."""
|
||||||
|
|
||||||
from .app import app, register_command
|
from .utils import async_command
|
||||||
|
|
||||||
__all__ = ["app", "register_command"]
|
__all__ = ["async_command"]
|
||||||
|
|||||||
@@ -1,97 +1,37 @@
|
|||||||
"""Main CLI application."""
|
"""Main CLI application."""
|
||||||
|
|
||||||
import importlib.util
|
try:
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Annotated
|
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
except ImportError:
|
||||||
|
from .._imports import require_extra
|
||||||
|
|
||||||
from .commands import fixtures
|
require_extra(package="typer", extra="cli")
|
||||||
|
|
||||||
app = typer.Typer(
|
from ..logger import configure_logging
|
||||||
name="fastapi-utils",
|
from .config import get_custom_cli
|
||||||
|
from .pyproject import load_pyproject
|
||||||
|
|
||||||
|
# Use custom CLI if configured, otherwise create default one
|
||||||
|
_custom_cli = get_custom_cli()
|
||||||
|
|
||||||
|
if _custom_cli is not None:
|
||||||
|
cli = _custom_cli
|
||||||
|
else:
|
||||||
|
cli = typer.Typer(
|
||||||
|
name="manager",
|
||||||
help="CLI utilities for FastAPI projects.",
|
help="CLI utilities for FastAPI projects.",
|
||||||
no_args_is_help=True,
|
no_args_is_help=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Register built-in commands
|
_config = load_pyproject()
|
||||||
app.add_typer(fixtures.app, name="fixtures")
|
if _config.get("fixtures") and _config.get("db_context"):
|
||||||
|
from .commands.fixtures import fixture_cli
|
||||||
|
|
||||||
|
cli.add_typer(fixture_cli, name="fixtures")
|
||||||
|
|
||||||
|
|
||||||
def register_command(command: typer.Typer, name: str) -> None:
|
@cli.callback()
|
||||||
"""Register a custom command group.
|
def main(ctx: typer.Context) -> None:
|
||||||
|
|
||||||
Args:
|
|
||||||
command: Typer app for the command group
|
|
||||||
name: Name for the command group
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# In your project's cli.py:
|
|
||||||
import typer
|
|
||||||
from fastapi_toolsets.cli import app, register_command
|
|
||||||
|
|
||||||
my_commands = typer.Typer()
|
|
||||||
|
|
||||||
@my_commands.command()
|
|
||||||
def seed():
|
|
||||||
'''Seed the database.'''
|
|
||||||
...
|
|
||||||
|
|
||||||
register_command(my_commands, "db")
|
|
||||||
# Now available as: fastapi-utils db seed
|
|
||||||
"""
|
|
||||||
app.add_typer(command, name=name)
|
|
||||||
|
|
||||||
|
|
||||||
@app.callback()
|
|
||||||
def main(
|
|
||||||
ctx: typer.Context,
|
|
||||||
config: Annotated[
|
|
||||||
Path | None,
|
|
||||||
typer.Option(
|
|
||||||
"--config",
|
|
||||||
"-c",
|
|
||||||
help="Path to project config file (Python module with fixtures registry).",
|
|
||||||
envvar="FASTAPI_TOOLSETS_CONFIG",
|
|
||||||
),
|
|
||||||
] = None,
|
|
||||||
) -> None:
|
|
||||||
"""FastAPI utilities CLI."""
|
"""FastAPI utilities CLI."""
|
||||||
|
configure_logging()
|
||||||
ctx.ensure_object(dict)
|
ctx.ensure_object(dict)
|
||||||
|
|
||||||
if config:
|
|
||||||
ctx.obj["config_path"] = config
|
|
||||||
# Load the config module
|
|
||||||
config_module = _load_module_from_path(config)
|
|
||||||
ctx.obj["config_module"] = config_module
|
|
||||||
|
|
||||||
|
|
||||||
def _load_module_from_path(path: Path) -> object:
|
|
||||||
"""Load a Python module from a file path.
|
|
||||||
|
|
||||||
Handles both absolute and relative imports by adding the config's
|
|
||||||
parent directory to sys.path temporarily.
|
|
||||||
"""
|
|
||||||
path = path.resolve()
|
|
||||||
|
|
||||||
# Add the parent directory to sys.path to support relative imports
|
|
||||||
parent_dir = str(
|
|
||||||
path.parent.parent
|
|
||||||
) # Go up two levels (e.g., from app/cli_config.py to project root)
|
|
||||||
if parent_dir not in sys.path:
|
|
||||||
sys.path.insert(0, parent_dir)
|
|
||||||
|
|
||||||
# Also add immediate parent for direct module imports
|
|
||||||
immediate_parent = str(path.parent)
|
|
||||||
if immediate_parent not in sys.path:
|
|
||||||
sys.path.insert(0, immediate_parent)
|
|
||||||
|
|
||||||
spec = importlib.util.spec_from_file_location("config", path)
|
|
||||||
if spec is None or spec.loader is None:
|
|
||||||
raise typer.BadParameter(f"Cannot load module from {path}")
|
|
||||||
|
|
||||||
module = importlib.util.module_from_spec(spec)
|
|
||||||
sys.modules["config"] = module
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
return module
|
|
||||||
|
|||||||
@@ -1,138 +1,66 @@
|
|||||||
"""Fixture management commands."""
|
"""Fixture management commands."""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
from ...fixtures import Context, FixtureRegistry, LoadStrategy, load_fixtures_by_context
|
from ...fixtures import Context, LoadStrategy, load_fixtures_by_context
|
||||||
|
from ..config import get_db_context, get_fixtures_registry
|
||||||
|
from ..utils import async_command
|
||||||
|
|
||||||
app = typer.Typer(
|
fixture_cli = typer.Typer(
|
||||||
name="fixtures",
|
name="fixtures",
|
||||||
help="Manage database fixtures.",
|
help="Manage database fixtures.",
|
||||||
no_args_is_help=True,
|
no_args_is_help=True,
|
||||||
)
|
)
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
def _get_registry(ctx: typer.Context) -> FixtureRegistry:
|
@fixture_cli.command("list")
|
||||||
"""Get fixture registry from context."""
|
|
||||||
config = ctx.obj.get("config_module") if ctx.obj else None
|
|
||||||
if config is None:
|
|
||||||
raise typer.BadParameter(
|
|
||||||
"No config provided. Use --config to specify a config file with a 'fixtures' registry."
|
|
||||||
)
|
|
||||||
|
|
||||||
registry = getattr(config, "fixtures", None)
|
|
||||||
if registry is None:
|
|
||||||
raise typer.BadParameter(
|
|
||||||
"Config module must have a 'fixtures' attribute (FixtureRegistry instance)."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not isinstance(registry, FixtureRegistry):
|
|
||||||
raise typer.BadParameter(
|
|
||||||
f"'fixtures' must be a FixtureRegistry instance, got {type(registry).__name__}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return registry
|
|
||||||
|
|
||||||
|
|
||||||
def _get_db_context(ctx: typer.Context):
|
|
||||||
"""Get database context manager from config."""
|
|
||||||
config = ctx.obj.get("config_module") if ctx.obj else None
|
|
||||||
if config is None:
|
|
||||||
raise typer.BadParameter("No config provided.")
|
|
||||||
|
|
||||||
get_db_context = getattr(config, "get_db_context", None)
|
|
||||||
if get_db_context is None:
|
|
||||||
raise typer.BadParameter("Config module must have a 'get_db_context' function.")
|
|
||||||
|
|
||||||
return get_db_context
|
|
||||||
|
|
||||||
|
|
||||||
@app.command("list")
|
|
||||||
def list_fixtures(
|
def list_fixtures(
|
||||||
ctx: typer.Context,
|
ctx: typer.Context,
|
||||||
context: Annotated[
|
context: Annotated[
|
||||||
str | None,
|
Context | None,
|
||||||
typer.Option(
|
typer.Option(
|
||||||
"--context",
|
"--context",
|
||||||
"-c",
|
"-c",
|
||||||
help="Filter by context (base, production, development, testing).",
|
help="Filter by context.",
|
||||||
),
|
),
|
||||||
] = None,
|
] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""List all registered fixtures."""
|
"""List all registered fixtures."""
|
||||||
registry = _get_registry(ctx)
|
registry = get_fixtures_registry()
|
||||||
|
fixtures = registry.get_by_context(context.value) if context else registry.get_all()
|
||||||
if context:
|
|
||||||
fixtures = registry.get_by_context(context)
|
|
||||||
else:
|
|
||||||
fixtures = registry.get_all()
|
|
||||||
|
|
||||||
if not fixtures:
|
if not fixtures:
|
||||||
typer.echo("No fixtures found.")
|
print("No fixtures found.")
|
||||||
return
|
return
|
||||||
|
|
||||||
typer.echo(f"\n{'Name':<30} {'Contexts':<30} {'Dependencies'}")
|
table = Table("Name", "Contexts", "Dependencies")
|
||||||
typer.echo("-" * 80)
|
|
||||||
|
|
||||||
for fixture in fixtures:
|
for fixture in fixtures:
|
||||||
contexts = ", ".join(fixture.contexts)
|
contexts = ", ".join(fixture.contexts)
|
||||||
deps = ", ".join(fixture.depends_on) if fixture.depends_on else "-"
|
deps = ", ".join(fixture.depends_on) if fixture.depends_on else "-"
|
||||||
typer.echo(f"{fixture.name:<30} {contexts:<30} {deps}")
|
table.add_row(fixture.name, contexts, deps)
|
||||||
|
|
||||||
typer.echo(f"\nTotal: {len(fixtures)} fixture(s)")
|
console.print(table)
|
||||||
|
print(f"\nTotal: {len(fixtures)} fixture(s)")
|
||||||
|
|
||||||
|
|
||||||
@app.command("graph")
|
@fixture_cli.command("load")
|
||||||
def show_graph(
|
@async_command
|
||||||
ctx: typer.Context,
|
async def load(
|
||||||
fixture_name: Annotated[
|
|
||||||
str | None,
|
|
||||||
typer.Argument(help="Show dependencies for a specific fixture."),
|
|
||||||
] = None,
|
|
||||||
) -> None:
|
|
||||||
"""Show fixture dependency graph."""
|
|
||||||
registry = _get_registry(ctx)
|
|
||||||
|
|
||||||
if fixture_name:
|
|
||||||
try:
|
|
||||||
order = registry.resolve_dependencies(fixture_name)
|
|
||||||
typer.echo(f"\nDependency chain for '{fixture_name}':\n")
|
|
||||||
for i, name in enumerate(order):
|
|
||||||
indent = " " * i
|
|
||||||
arrow = "└─> " if i > 0 else ""
|
|
||||||
typer.echo(f"{indent}{arrow}{name}")
|
|
||||||
except KeyError:
|
|
||||||
typer.echo(f"Fixture '{fixture_name}' not found.", err=True)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
else:
|
|
||||||
# Show full graph
|
|
||||||
fixtures = registry.get_all()
|
|
||||||
|
|
||||||
typer.echo("\nFixture Dependency Graph:\n")
|
|
||||||
for fixture in fixtures:
|
|
||||||
deps = (
|
|
||||||
f" -> [{', '.join(fixture.depends_on)}]" if fixture.depends_on else ""
|
|
||||||
)
|
|
||||||
typer.echo(f" {fixture.name}{deps}")
|
|
||||||
|
|
||||||
|
|
||||||
@app.command("load")
|
|
||||||
def load(
|
|
||||||
ctx: typer.Context,
|
ctx: typer.Context,
|
||||||
contexts: Annotated[
|
contexts: Annotated[
|
||||||
list[str] | None,
|
list[Context] | None,
|
||||||
typer.Argument(
|
typer.Argument(help="Contexts to load."),
|
||||||
help="Contexts to load (base, production, development, testing)."
|
|
||||||
),
|
|
||||||
] = None,
|
] = None,
|
||||||
strategy: Annotated[
|
strategy: Annotated[
|
||||||
str,
|
LoadStrategy,
|
||||||
typer.Option(
|
typer.Option("--strategy", "-s", help="Load strategy."),
|
||||||
"--strategy", "-s", help="Load strategy: merge, insert, skip_existing."
|
] = LoadStrategy.MERGE,
|
||||||
),
|
|
||||||
] = "merge",
|
|
||||||
dry_run: Annotated[
|
dry_run: Annotated[
|
||||||
bool,
|
bool,
|
||||||
typer.Option(
|
typer.Option(
|
||||||
@@ -141,85 +69,32 @@ def load(
|
|||||||
] = False,
|
] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Load fixtures into the database."""
|
"""Load fixtures into the database."""
|
||||||
registry = _get_registry(ctx)
|
registry = get_fixtures_registry()
|
||||||
get_db_context = _get_db_context(ctx)
|
db_context = get_db_context()
|
||||||
|
|
||||||
# Parse contexts
|
context_list = list(contexts) if contexts else [Context.BASE]
|
||||||
if contexts:
|
|
||||||
context_list = contexts
|
|
||||||
else:
|
|
||||||
context_list = [Context.BASE]
|
|
||||||
|
|
||||||
# Parse strategy
|
|
||||||
try:
|
|
||||||
load_strategy = LoadStrategy(strategy)
|
|
||||||
except ValueError:
|
|
||||||
typer.echo(
|
|
||||||
f"Invalid strategy: {strategy}. Use: merge, insert, skip_existing", err=True
|
|
||||||
)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
# Resolve what will be loaded
|
|
||||||
ordered = registry.resolve_context_dependencies(*context_list)
|
ordered = registry.resolve_context_dependencies(*context_list)
|
||||||
|
|
||||||
if not ordered:
|
if not ordered:
|
||||||
typer.echo("No fixtures to load for the specified context(s).")
|
print("No fixtures to load for the specified context(s).")
|
||||||
return
|
return
|
||||||
|
|
||||||
typer.echo(f"\nFixtures to load ({load_strategy.value} strategy):")
|
print(f"\nFixtures to load ({strategy.value} strategy):")
|
||||||
for name in ordered:
|
for name in ordered:
|
||||||
fixture = registry.get(name)
|
fixture = registry.get(name)
|
||||||
instances = list(fixture.func())
|
instances = list(fixture.func())
|
||||||
model_name = type(instances[0]).__name__ if instances else "?"
|
model_name = type(instances[0]).__name__ if instances else "?"
|
||||||
typer.echo(f" - {name}: {len(instances)} {model_name}(s)")
|
print(f" - {name}: {len(instances)} {model_name}(s)")
|
||||||
|
|
||||||
if dry_run:
|
if dry_run:
|
||||||
typer.echo("\n[Dry run - no changes made]")
|
print("\n[Dry run - no changes made]")
|
||||||
return
|
return
|
||||||
|
|
||||||
typer.echo("\nLoading...")
|
async with db_context() as session:
|
||||||
|
|
||||||
async def do_load():
|
|
||||||
async with get_db_context() as session:
|
|
||||||
result = await load_fixtures_by_context(
|
result = await load_fixtures_by_context(
|
||||||
session, registry, *context_list, strategy=load_strategy
|
session, registry, *context_list, strategy=strategy
|
||||||
)
|
)
|
||||||
return result
|
|
||||||
|
|
||||||
result = asyncio.run(do_load())
|
|
||||||
|
|
||||||
total = sum(len(items) for items in result.values())
|
total = sum(len(items) for items in result.values())
|
||||||
typer.echo(f"\nLoaded {total} record(s) successfully.")
|
print(f"\nLoaded {total} record(s) successfully.")
|
||||||
|
|
||||||
|
|
||||||
@app.command("show")
|
|
||||||
def show_fixture(
|
|
||||||
ctx: typer.Context,
|
|
||||||
name: Annotated[str, typer.Argument(help="Fixture name to show.")],
|
|
||||||
) -> None:
|
|
||||||
"""Show details of a specific fixture."""
|
|
||||||
registry = _get_registry(ctx)
|
|
||||||
|
|
||||||
try:
|
|
||||||
fixture = registry.get(name)
|
|
||||||
except KeyError:
|
|
||||||
typer.echo(f"Fixture '{name}' not found.", err=True)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
typer.echo(f"\nFixture: {fixture.name}")
|
|
||||||
typer.echo(f"Contexts: {', '.join(fixture.contexts)}")
|
|
||||||
typer.echo(
|
|
||||||
f"Dependencies: {', '.join(fixture.depends_on) if fixture.depends_on else 'None'}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Show instances
|
|
||||||
instances = list(fixture.func())
|
|
||||||
if instances:
|
|
||||||
model_name = type(instances[0]).__name__
|
|
||||||
typer.echo(f"\nInstances ({len(instances)} {model_name}):")
|
|
||||||
for instance in instances[:10]: # Limit to 10
|
|
||||||
typer.echo(f" - {instance!r}")
|
|
||||||
if len(instances) > 10:
|
|
||||||
typer.echo(f" ... and {len(instances) - 10} more")
|
|
||||||
else:
|
|
||||||
typer.echo("\nNo instances (empty fixture)")
|
|
||||||
|
|||||||
125
src/fastapi_toolsets/cli/config.py
Normal file
125
src/fastapi_toolsets/cli/config.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""CLI configuration and dynamic imports."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
from typing import TYPE_CHECKING, Any, Literal, overload
|
||||||
|
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from .pyproject import find_pyproject, load_pyproject
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..fixtures import FixtureRegistry
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_project_in_path():
|
||||||
|
"""Add project root to sys.path if not installed in editable mode."""
|
||||||
|
pyproject = find_pyproject()
|
||||||
|
if pyproject:
|
||||||
|
project_root = str(pyproject.parent)
|
||||||
|
if project_root not in sys.path:
|
||||||
|
sys.path.insert(0, project_root)
|
||||||
|
|
||||||
|
|
||||||
|
def import_from_string(import_path: str) -> Any:
|
||||||
|
"""Import an object from a dotted string path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
import_path: Import path in ``"module.submodule:attribute"`` format
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The imported attribute
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
typer.BadParameter: If the import path is invalid or import fails
|
||||||
|
"""
|
||||||
|
if ":" not in import_path:
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"Invalid import path '{import_path}'. Expected format: 'module:attribute'"
|
||||||
|
)
|
||||||
|
|
||||||
|
module_path, attr_name = import_path.rsplit(":", 1)
|
||||||
|
|
||||||
|
_ensure_project_in_path()
|
||||||
|
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(module_path)
|
||||||
|
except ImportError as e:
|
||||||
|
raise typer.BadParameter(f"Cannot import module '{module_path}': {e}")
|
||||||
|
|
||||||
|
if not hasattr(module, attr_name):
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"Module '{module_path}' has no attribute '{attr_name}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
return getattr(module, attr_name)
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_config_value(key: str, required: Literal[True]) -> Any: ... # pragma: no cover
|
||||||
|
@overload
|
||||||
|
def get_config_value(
|
||||||
|
key: str, required: bool = False
|
||||||
|
) -> Any | None: ... # pragma: no cover
|
||||||
|
def get_config_value(key: str, required: bool = False) -> Any | None:
|
||||||
|
"""Get a configuration value from pyproject.toml.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The configuration key in [tool.fastapi-toolsets].
|
||||||
|
required: If True, raises an error when the key is missing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The configuration value, or None if not found and not required.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
typer.BadParameter: If required=True and the key is missing.
|
||||||
|
"""
|
||||||
|
config = load_pyproject()
|
||||||
|
value = config.get(key)
|
||||||
|
|
||||||
|
if required and value is None:
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"No '{key}' configured. "
|
||||||
|
f"Add '{key}' to [tool.fastapi-toolsets] in pyproject.toml."
|
||||||
|
)
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def get_fixtures_registry() -> FixtureRegistry:
|
||||||
|
"""Import and return the fixtures registry from config."""
|
||||||
|
from ..fixtures import FixtureRegistry
|
||||||
|
|
||||||
|
import_path = get_config_value("fixtures", required=True)
|
||||||
|
registry = import_from_string(import_path)
|
||||||
|
|
||||||
|
if not isinstance(registry, FixtureRegistry):
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"'fixtures' must be a FixtureRegistry instance, got {type(registry).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return registry
|
||||||
|
|
||||||
|
|
||||||
|
def get_db_context() -> Any:
|
||||||
|
"""Import and return the db_context function from config."""
|
||||||
|
import_path = get_config_value("db_context", required=True)
|
||||||
|
return import_from_string(import_path)
|
||||||
|
|
||||||
|
|
||||||
|
def get_custom_cli() -> typer.Typer | None:
|
||||||
|
"""Import and return the custom CLI Typer instance from config."""
|
||||||
|
import_path = get_config_value("custom_cli")
|
||||||
|
if not import_path:
|
||||||
|
return None
|
||||||
|
|
||||||
|
custom = import_from_string(import_path)
|
||||||
|
|
||||||
|
if not isinstance(custom, typer.Typer):
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"'custom_cli' must be a Typer instance, got {type(custom).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return custom
|
||||||
49
src/fastapi_toolsets/cli/pyproject.py
Normal file
49
src/fastapi_toolsets/cli/pyproject.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
"""Pyproject.toml discovery and loading."""
|
||||||
|
|
||||||
|
import tomllib
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
TOOL_NAME = "fastapi-toolsets"
|
||||||
|
|
||||||
|
|
||||||
|
def find_pyproject(start_path: Path | None = None) -> Path | None:
|
||||||
|
"""Find pyproject.toml by walking up the directory tree.
|
||||||
|
|
||||||
|
Similar to how pytest, black, and ruff discover their config files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_path: Directory to start searching from. Defaults to cwd.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to pyproject.toml, or None if not found.
|
||||||
|
"""
|
||||||
|
path = (start_path or Path.cwd()).resolve()
|
||||||
|
|
||||||
|
for directory in [path, *path.parents]:
|
||||||
|
pyproject = directory / "pyproject.toml"
|
||||||
|
if pyproject.is_file():
|
||||||
|
return pyproject
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def load_pyproject(path: Path | None = None) -> dict:
|
||||||
|
"""Load tool configuration from pyproject.toml.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Explicit path to pyproject.toml. If None, searches up from cwd.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The [tool.fastapi-toolsets] section as a dict, or empty dict if not found.
|
||||||
|
"""
|
||||||
|
pyproject_path = path or find_pyproject()
|
||||||
|
|
||||||
|
if not pyproject_path:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(pyproject_path, "rb") as f:
|
||||||
|
data = tomllib.load(f)
|
||||||
|
return data.get("tool", {}).get(TOOL_NAME, {})
|
||||||
|
except (OSError, tomllib.TOMLDecodeError):
|
||||||
|
return {}
|
||||||
29
src/fastapi_toolsets/cli/utils.py
Normal file
29
src/fastapi_toolsets/cli/utils.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
"""CLI utility functions."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import functools
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
|
from typing import Any, ParamSpec, TypeVar
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def async_command(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, T]:
|
||||||
|
"""Decorator to run an async function as a sync CLI command.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
@fixture_cli.command("load")
|
||||||
|
@async_command
|
||||||
|
async def load(ctx: typer.Context) -> None:
|
||||||
|
async with get_db_context() as session:
|
||||||
|
await load_fixtures(session, registry)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||||
|
return asyncio.run(func(*args, **kwargs))
|
||||||
|
|
||||||
|
return wrapper
|
||||||
@@ -1,378 +0,0 @@
|
|||||||
"""Generic async CRUD operations for SQLAlchemy models."""
|
|
||||||
|
|
||||||
from collections.abc import Sequence
|
|
||||||
from typing import Any, ClassVar, Generic, Self, TypeVar, cast
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from sqlalchemy import and_, func, select
|
|
||||||
from sqlalchemy import delete as sql_delete
|
|
||||||
from sqlalchemy.dialects.postgresql import insert
|
|
||||||
from sqlalchemy.exc import NoResultFound
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
|
||||||
from sqlalchemy.sql.roles import WhereHavingRole
|
|
||||||
|
|
||||||
from .db import get_transaction
|
|
||||||
from .exceptions import NotFoundError
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"AsyncCrud",
|
|
||||||
"CrudFactory",
|
|
||||||
]
|
|
||||||
|
|
||||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncCrud(Generic[ModelType]):
|
|
||||||
"""Generic async CRUD operations for SQLAlchemy models.
|
|
||||||
|
|
||||||
Subclass this and set the `model` class variable, or use `CrudFactory`.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
class UserCrud(AsyncCrud[User]):
|
|
||||||
model = User
|
|
||||||
|
|
||||||
# Or use the factory:
|
|
||||||
UserCrud = CrudFactory(User)
|
|
||||||
|
|
||||||
# Then use it:
|
|
||||||
user = await UserCrud.get(session, [User.id == 1])
|
|
||||||
users = await UserCrud.get_multi(session, limit=10)
|
|
||||||
"""
|
|
||||||
|
|
||||||
model: ClassVar[type[DeclarativeBase]]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def create(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
obj: BaseModel,
|
|
||||||
) -> ModelType:
|
|
||||||
"""Create a new record in the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
obj: Pydantic model with data to create
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Created model instance
|
|
||||||
"""
|
|
||||||
async with get_transaction(session):
|
|
||||||
db_model = cls.model(**obj.model_dump())
|
|
||||||
session.add(db_model)
|
|
||||||
await session.refresh(db_model)
|
|
||||||
return cast(ModelType, db_model)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
filters: list[Any],
|
|
||||||
*,
|
|
||||||
with_for_update: bool = False,
|
|
||||||
load_options: list[Any] | None = None,
|
|
||||||
) -> ModelType:
|
|
||||||
"""Get exactly one record. Raises NotFoundError if not found.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
with_for_update: Lock the row for update
|
|
||||||
load_options: SQLAlchemy loader options (e.g., selectinload)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Model instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
NotFoundError: If no record found
|
|
||||||
MultipleResultsFound: If more than one record found
|
|
||||||
"""
|
|
||||||
q = select(cls.model).where(and_(*filters))
|
|
||||||
if load_options:
|
|
||||||
q = q.options(*load_options)
|
|
||||||
if with_for_update:
|
|
||||||
q = q.with_for_update()
|
|
||||||
result = await session.execute(q)
|
|
||||||
item = result.unique().scalar_one_or_none()
|
|
||||||
if not item:
|
|
||||||
raise NotFoundError()
|
|
||||||
return cast(ModelType, item)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def first(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
filters: list[Any] | None = None,
|
|
||||||
*,
|
|
||||||
load_options: list[Any] | None = None,
|
|
||||||
) -> ModelType | None:
|
|
||||||
"""Get the first matching record, or None.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
load_options: SQLAlchemy loader options
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Model instance or None
|
|
||||||
"""
|
|
||||||
q = select(cls.model)
|
|
||||||
if filters:
|
|
||||||
q = q.where(and_(*filters))
|
|
||||||
if load_options:
|
|
||||||
q = q.options(*load_options)
|
|
||||||
result = await session.execute(q)
|
|
||||||
return cast(ModelType | None, result.unique().scalars().first())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get_multi(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
*,
|
|
||||||
filters: list[Any] | None = None,
|
|
||||||
load_options: list[Any] | None = None,
|
|
||||||
order_by: Any | None = None,
|
|
||||||
limit: int | None = None,
|
|
||||||
offset: int | None = None,
|
|
||||||
) -> Sequence[ModelType]:
|
|
||||||
"""Get multiple records from the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
load_options: SQLAlchemy loader options
|
|
||||||
order_by: Column or list of columns to order by
|
|
||||||
limit: Max number of rows to return
|
|
||||||
offset: Rows to skip
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of model instances
|
|
||||||
"""
|
|
||||||
q = select(cls.model)
|
|
||||||
if filters:
|
|
||||||
q = q.where(and_(*filters))
|
|
||||||
if load_options:
|
|
||||||
q = q.options(*load_options)
|
|
||||||
if order_by is not None:
|
|
||||||
q = q.order_by(order_by)
|
|
||||||
if offset is not None:
|
|
||||||
q = q.offset(offset)
|
|
||||||
if limit is not None:
|
|
||||||
q = q.limit(limit)
|
|
||||||
result = await session.execute(q)
|
|
||||||
return cast(Sequence[ModelType], result.unique().scalars().all())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def update(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
obj: BaseModel,
|
|
||||||
filters: list[Any],
|
|
||||||
*,
|
|
||||||
exclude_unset: bool = True,
|
|
||||||
exclude_none: bool = False,
|
|
||||||
) -> ModelType:
|
|
||||||
"""Update a record in the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
obj: Pydantic model with update data
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
exclude_unset: Exclude fields not explicitly set in the schema
|
|
||||||
exclude_none: Exclude fields with None value
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated model instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
NotFoundError: If no record found
|
|
||||||
"""
|
|
||||||
async with get_transaction(session):
|
|
||||||
db_model = await cls.get(session=session, filters=filters)
|
|
||||||
values = obj.model_dump(
|
|
||||||
exclude_unset=exclude_unset, exclude_none=exclude_none
|
|
||||||
)
|
|
||||||
for key, value in values.items():
|
|
||||||
setattr(db_model, key, value)
|
|
||||||
await session.refresh(db_model)
|
|
||||||
return db_model
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def upsert(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
obj: BaseModel,
|
|
||||||
index_elements: list[str],
|
|
||||||
*,
|
|
||||||
set_: BaseModel | None = None,
|
|
||||||
where: WhereHavingRole | None = None,
|
|
||||||
) -> ModelType | None:
|
|
||||||
"""Create or update a record (PostgreSQL only).
|
|
||||||
|
|
||||||
Uses INSERT ... ON CONFLICT for atomic upsert.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
obj: Pydantic model with data
|
|
||||||
index_elements: Columns for ON CONFLICT (unique constraint)
|
|
||||||
set_: Pydantic model for ON CONFLICT DO UPDATE SET
|
|
||||||
where: WHERE clause for ON CONFLICT DO UPDATE
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Model instance
|
|
||||||
"""
|
|
||||||
async with get_transaction(session):
|
|
||||||
values = obj.model_dump(exclude_unset=True)
|
|
||||||
q = insert(cls.model).values(**values)
|
|
||||||
if set_:
|
|
||||||
q = q.on_conflict_do_update(
|
|
||||||
index_elements=index_elements,
|
|
||||||
set_=set_.model_dump(exclude_unset=True),
|
|
||||||
where=where,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
q = q.on_conflict_do_nothing(index_elements=index_elements)
|
|
||||||
q = q.returning(cls.model)
|
|
||||||
result = await session.execute(q)
|
|
||||||
try:
|
|
||||||
db_model = result.unique().scalar_one()
|
|
||||||
except NoResultFound:
|
|
||||||
db_model = await cls.first(
|
|
||||||
session=session,
|
|
||||||
filters=[getattr(cls.model, k) == v for k, v in values.items()],
|
|
||||||
)
|
|
||||||
return cast(ModelType | None, db_model)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def delete(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
filters: list[Any],
|
|
||||||
) -> bool:
|
|
||||||
"""Delete records from the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if deletion was executed
|
|
||||||
"""
|
|
||||||
async with get_transaction(session):
|
|
||||||
q = sql_delete(cls.model).where(and_(*filters))
|
|
||||||
await session.execute(q)
|
|
||||||
return True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def count(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
filters: list[Any] | None = None,
|
|
||||||
) -> int:
|
|
||||||
"""Count records matching the filters.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of matching records
|
|
||||||
"""
|
|
||||||
q = select(func.count()).select_from(cls.model)
|
|
||||||
if filters:
|
|
||||||
q = q.where(and_(*filters))
|
|
||||||
result = await session.execute(q)
|
|
||||||
return result.scalar_one()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def exists(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
filters: list[Any],
|
|
||||||
) -> bool:
|
|
||||||
"""Check if a record exists.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if at least one record matches
|
|
||||||
"""
|
|
||||||
q = select(cls.model).where(and_(*filters)).exists().select()
|
|
||||||
result = await session.execute(q)
|
|
||||||
return bool(result.scalar())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def paginate(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
*,
|
|
||||||
filters: list[Any] | None = None,
|
|
||||||
load_options: list[Any] | None = None,
|
|
||||||
order_by: Any | None = None,
|
|
||||||
page: int = 1,
|
|
||||||
items_per_page: int = 20,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Get paginated results with metadata.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
load_options: SQLAlchemy loader options
|
|
||||||
order_by: Column or list of columns to order by
|
|
||||||
page: Page number (1-indexed)
|
|
||||||
items_per_page: Number of items per page
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict with 'data' and 'pagination' keys
|
|
||||||
"""
|
|
||||||
filters = filters or []
|
|
||||||
offset = (page - 1) * items_per_page
|
|
||||||
|
|
||||||
items = await cls.get_multi(
|
|
||||||
session,
|
|
||||||
filters=filters,
|
|
||||||
load_options=load_options,
|
|
||||||
order_by=order_by,
|
|
||||||
limit=items_per_page,
|
|
||||||
offset=offset,
|
|
||||||
)
|
|
||||||
|
|
||||||
total_count = await cls.count(session, filters=filters)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"data": items,
|
|
||||||
"pagination": {
|
|
||||||
"total_count": total_count,
|
|
||||||
"items_per_page": items_per_page,
|
|
||||||
"page": page,
|
|
||||||
"has_more": page * items_per_page < total_count,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def CrudFactory(
|
|
||||||
model: type[ModelType],
|
|
||||||
) -> type[AsyncCrud[ModelType]]:
|
|
||||||
"""Create a CRUD class for a specific model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: SQLAlchemy model class
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AsyncCrud subclass bound to the model
|
|
||||||
|
|
||||||
Example:
|
|
||||||
from fastapi_toolsets.crud import CrudFactory
|
|
||||||
from myapp.models import User, Post
|
|
||||||
|
|
||||||
UserCrud = CrudFactory(User)
|
|
||||||
PostCrud = CrudFactory(Post)
|
|
||||||
|
|
||||||
# Usage
|
|
||||||
user = await UserCrud.get(session, [User.id == 1])
|
|
||||||
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
|
|
||||||
"""
|
|
||||||
cls = type(f"Async{model.__name__}Crud", (AsyncCrud,), {"model": model})
|
|
||||||
return cast(type[AsyncCrud[ModelType]], cls)
|
|
||||||
35
src/fastapi_toolsets/crud/__init__.py
Normal file
35
src/fastapi_toolsets/crud/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""Generic async CRUD operations for SQLAlchemy models."""
|
||||||
|
|
||||||
|
from ..exceptions import (
|
||||||
|
InvalidFacetFilterError,
|
||||||
|
InvalidSearchColumnError,
|
||||||
|
NoSearchableFieldsError,
|
||||||
|
UnsupportedFacetTypeError,
|
||||||
|
)
|
||||||
|
from ..schemas import PaginationType
|
||||||
|
from ..types import (
|
||||||
|
FacetFieldType,
|
||||||
|
JoinType,
|
||||||
|
M2MFieldType,
|
||||||
|
OrderByClause,
|
||||||
|
SearchFieldType,
|
||||||
|
)
|
||||||
|
from .factory import AsyncCrud, CrudFactory
|
||||||
|
from .search import SearchConfig, get_searchable_fields
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AsyncCrud",
|
||||||
|
"CrudFactory",
|
||||||
|
"FacetFieldType",
|
||||||
|
"get_searchable_fields",
|
||||||
|
"InvalidFacetFilterError",
|
||||||
|
"InvalidSearchColumnError",
|
||||||
|
"JoinType",
|
||||||
|
"M2MFieldType",
|
||||||
|
"NoSearchableFieldsError",
|
||||||
|
"OrderByClause",
|
||||||
|
"PaginationType",
|
||||||
|
"SearchConfig",
|
||||||
|
"SearchFieldType",
|
||||||
|
"UnsupportedFacetTypeError",
|
||||||
|
]
|
||||||
1771
src/fastapi_toolsets/crud/factory.py
Normal file
1771
src/fastapi_toolsets/crud/factory.py
Normal file
File diff suppressed because it is too large
Load Diff
345
src/fastapi_toolsets/crud/search.py
Normal file
345
src/fastapi_toolsets/crud/search.py
Normal file
@@ -0,0 +1,345 @@
|
|||||||
|
"""Search utilities for AsyncCrud."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import functools
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass, replace
|
||||||
|
from typing import TYPE_CHECKING, Any, Literal
|
||||||
|
|
||||||
|
from sqlalchemy import String, and_, func, or_, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||||
|
from sqlalchemy.types import (
|
||||||
|
ARRAY,
|
||||||
|
Boolean,
|
||||||
|
Date,
|
||||||
|
DateTime,
|
||||||
|
Enum,
|
||||||
|
Integer,
|
||||||
|
Numeric,
|
||||||
|
Time,
|
||||||
|
Uuid,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..exceptions import (
|
||||||
|
InvalidFacetFilterError,
|
||||||
|
InvalidSearchColumnError,
|
||||||
|
NoSearchableFieldsError,
|
||||||
|
UnsupportedFacetTypeError,
|
||||||
|
)
|
||||||
|
from ..types import FacetFieldType, SearchFieldType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sqlalchemy.sql.elements import ColumnElement
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SearchConfig:
|
||||||
|
"""Advanced search configuration.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
query: The search string
|
||||||
|
fields: Fields to search (columns or tuples for relationships)
|
||||||
|
case_sensitive: Case-sensitive search (default: False)
|
||||||
|
match_mode: "any" (OR) or "all" (AND) to combine fields
|
||||||
|
"""
|
||||||
|
|
||||||
|
query: str
|
||||||
|
fields: Sequence[SearchFieldType] | None = None
|
||||||
|
case_sensitive: bool = False
|
||||||
|
match_mode: Literal["any", "all"] = "any"
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=128)
|
||||||
|
def get_searchable_fields(
|
||||||
|
model: type[DeclarativeBase],
|
||||||
|
*,
|
||||||
|
include_relationships: bool = True,
|
||||||
|
max_depth: int = 1,
|
||||||
|
) -> list[SearchFieldType]:
|
||||||
|
"""Auto-detect String fields on a model and its relationships.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
include_relationships: Include fields from many-to-one/one-to-one relationships
|
||||||
|
max_depth: Max depth for relationship traversal (default: 1)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of columns and tuples (relationship, column)
|
||||||
|
"""
|
||||||
|
fields: list[SearchFieldType] = []
|
||||||
|
mapper = model.__mapper__
|
||||||
|
|
||||||
|
# Direct String columns
|
||||||
|
for col in mapper.columns:
|
||||||
|
if isinstance(col.type, String):
|
||||||
|
fields.append(getattr(model, col.key))
|
||||||
|
|
||||||
|
# Relationships (one-to-one, many-to-one only)
|
||||||
|
if include_relationships and max_depth > 0:
|
||||||
|
for rel_name, rel_prop in mapper.relationships.items():
|
||||||
|
if rel_prop.uselist: # Skip collections (one-to-many, many-to-many)
|
||||||
|
continue
|
||||||
|
|
||||||
|
rel_attr = getattr(model, rel_name)
|
||||||
|
related_model = rel_prop.mapper.class_
|
||||||
|
|
||||||
|
for col in related_model.__mapper__.columns:
|
||||||
|
if isinstance(col.type, String):
|
||||||
|
fields.append((rel_attr, getattr(related_model, col.key)))
|
||||||
|
|
||||||
|
return fields
|
||||||
|
|
||||||
|
|
||||||
|
def build_search_filters(
|
||||||
|
model: type[DeclarativeBase],
|
||||||
|
search: str | SearchConfig,
|
||||||
|
search_fields: Sequence[SearchFieldType] | None = None,
|
||||||
|
default_fields: Sequence[SearchFieldType] | None = None,
|
||||||
|
search_column: str | None = None,
|
||||||
|
) -> tuple[list["ColumnElement[bool]"], list[InstrumentedAttribute[Any]]]:
|
||||||
|
"""Build SQLAlchemy filter conditions for search.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
search: Search string or SearchConfig
|
||||||
|
search_fields: Fields specified per-call (takes priority)
|
||||||
|
default_fields: Default fields (from ClassVar)
|
||||||
|
search_column: Optional key to narrow search to a single field.
|
||||||
|
Must match one of the resolved search field keys.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (filter_conditions, joins_needed)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NoSearchableFieldsError: If no searchable field has been configured
|
||||||
|
"""
|
||||||
|
# Normalize input
|
||||||
|
if isinstance(search, str):
|
||||||
|
config = SearchConfig(query=search, fields=search_fields)
|
||||||
|
else:
|
||||||
|
config = (
|
||||||
|
replace(search, fields=search_fields)
|
||||||
|
if search_fields is not None
|
||||||
|
else search
|
||||||
|
)
|
||||||
|
|
||||||
|
if not config.query or not config.query.strip():
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
# Determine which fields to search
|
||||||
|
fields = config.fields or default_fields or get_searchable_fields(model)
|
||||||
|
|
||||||
|
if not fields:
|
||||||
|
raise NoSearchableFieldsError(model)
|
||||||
|
|
||||||
|
# Narrow to a single column when search_column is specified
|
||||||
|
if search_column is not None:
|
||||||
|
keys = search_field_keys(fields)
|
||||||
|
index = {k: f for k, f in zip(keys, fields)}
|
||||||
|
if search_column not in index:
|
||||||
|
raise InvalidSearchColumnError(search_column, sorted(index))
|
||||||
|
fields = [index[search_column]]
|
||||||
|
|
||||||
|
query = config.query.strip()
|
||||||
|
filters: list[ColumnElement[bool]] = []
|
||||||
|
joins: list[InstrumentedAttribute[Any]] = []
|
||||||
|
added_joins: set[str] = set()
|
||||||
|
|
||||||
|
for field in fields:
|
||||||
|
if isinstance(field, tuple):
|
||||||
|
# Relationship: (User.role, Role.name) or deeper
|
||||||
|
for rel in field[:-1]:
|
||||||
|
rel_key = str(rel)
|
||||||
|
if rel_key not in added_joins:
|
||||||
|
joins.append(rel)
|
||||||
|
added_joins.add(rel_key)
|
||||||
|
column = field[-1]
|
||||||
|
else:
|
||||||
|
column = field
|
||||||
|
|
||||||
|
# Build the filter (cast to String for non-text columns)
|
||||||
|
column_as_string = column.cast(String)
|
||||||
|
if config.case_sensitive:
|
||||||
|
filters.append(column_as_string.like(f"%{query}%"))
|
||||||
|
else:
|
||||||
|
filters.append(column_as_string.ilike(f"%{query}%"))
|
||||||
|
|
||||||
|
if not filters: # pragma: no cover
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
# Combine based on match_mode
|
||||||
|
if config.match_mode == "any":
|
||||||
|
return [or_(*filters)], joins
|
||||||
|
else:
|
||||||
|
return filters, joins
|
||||||
|
|
||||||
|
|
||||||
|
def search_field_keys(fields: Sequence[SearchFieldType]) -> list[str]:
|
||||||
|
"""Return a human-readable key for each search field."""
|
||||||
|
return facet_keys(fields)
|
||||||
|
|
||||||
|
|
||||||
|
def facet_keys(facet_fields: Sequence[FacetFieldType]) -> list[str]:
|
||||||
|
"""Return a key for each facet field.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
facet_fields: Sequence of facet fields — either direct columns or
|
||||||
|
relationship tuples ``(rel, ..., column)``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of string keys, one per facet field, in the same order.
|
||||||
|
"""
|
||||||
|
keys: list[str] = []
|
||||||
|
for field in facet_fields:
|
||||||
|
if isinstance(field, tuple):
|
||||||
|
keys.append("__".join(el.key for el in field))
|
||||||
|
else:
|
||||||
|
keys.append(field.key)
|
||||||
|
return keys
|
||||||
|
|
||||||
|
|
||||||
|
async def build_facets(
|
||||||
|
session: "AsyncSession",
|
||||||
|
model: type[DeclarativeBase],
|
||||||
|
facet_fields: Sequence[FacetFieldType],
|
||||||
|
*,
|
||||||
|
base_filters: "list[ColumnElement[bool]] | None" = None,
|
||||||
|
base_joins: list[InstrumentedAttribute[Any]] | None = None,
|
||||||
|
) -> dict[str, list[Any]]:
|
||||||
|
"""Return distinct values for each facet field, respecting current filters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: DB async session
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
facet_fields: Columns or relationship tuples to facet on
|
||||||
|
base_filters: Filter conditions already applied to the main query (search + caller filters)
|
||||||
|
base_joins: Relationship joins already applied to the main query
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping column key to sorted list of distinct non-None values
|
||||||
|
"""
|
||||||
|
existing_join_keys: set[str] = {str(j) for j in (base_joins or [])}
|
||||||
|
|
||||||
|
keys = facet_keys(facet_fields)
|
||||||
|
|
||||||
|
async def _query_facet(field: FacetFieldType, key: str) -> tuple[str, list[Any]]:
|
||||||
|
if isinstance(field, tuple):
|
||||||
|
# Relationship chain: (User.role, Role.name) — last element is the column
|
||||||
|
rels = field[:-1]
|
||||||
|
column = field[-1]
|
||||||
|
else:
|
||||||
|
rels = ()
|
||||||
|
column = field
|
||||||
|
|
||||||
|
col_type = column.property.columns[0].type
|
||||||
|
is_array = isinstance(col_type, ARRAY)
|
||||||
|
|
||||||
|
if is_array:
|
||||||
|
unnested = func.unnest(column).label(column.key)
|
||||||
|
q = select(unnested).select_from(model).distinct()
|
||||||
|
else:
|
||||||
|
q = select(column).select_from(model).distinct()
|
||||||
|
|
||||||
|
# Apply base joins (deduplicated) — needed here independently
|
||||||
|
seen_joins: set[str] = set()
|
||||||
|
for rel in base_joins or []:
|
||||||
|
rel_key = str(rel)
|
||||||
|
if rel_key not in seen_joins:
|
||||||
|
seen_joins.add(rel_key)
|
||||||
|
q = q.outerjoin(rel)
|
||||||
|
|
||||||
|
# Add any extra joins required by this facet field that aren't already applied
|
||||||
|
for rel in rels:
|
||||||
|
rel_key = str(rel)
|
||||||
|
if rel_key not in existing_join_keys and rel_key not in seen_joins:
|
||||||
|
seen_joins.add(rel_key)
|
||||||
|
q = q.outerjoin(rel)
|
||||||
|
|
||||||
|
if base_filters:
|
||||||
|
q = q.where(and_(*base_filters))
|
||||||
|
|
||||||
|
if is_array:
|
||||||
|
q = q.order_by(unnested)
|
||||||
|
else:
|
||||||
|
q = q.order_by(column)
|
||||||
|
result = await session.execute(q)
|
||||||
|
values = [row[0] for row in result.all() if row[0] is not None]
|
||||||
|
return key, values
|
||||||
|
|
||||||
|
pairs = await asyncio.gather(
|
||||||
|
*[_query_facet(f, k) for f, k in zip(facet_fields, keys)]
|
||||||
|
)
|
||||||
|
return dict(pairs)
|
||||||
|
|
||||||
|
|
||||||
|
_EQUALITY_TYPES = (String, Integer, Numeric, Date, DateTime, Time, Enum, Uuid)
|
||||||
|
"""Column types that support equality / IN filtering in build_filter_by."""
|
||||||
|
|
||||||
|
|
||||||
|
def build_filter_by(
|
||||||
|
filter_by: dict[str, Any],
|
||||||
|
facet_fields: Sequence[FacetFieldType],
|
||||||
|
) -> tuple["list[ColumnElement[bool]]", list[InstrumentedAttribute[Any]]]:
|
||||||
|
"""Translate a {column_key: value} dict into SQLAlchemy filter conditions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filter_by: Mapping of column key to scalar value or list of values
|
||||||
|
facet_fields: Declared facet fields to validate keys against
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (filter_conditions, joins_needed)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidFacetFilterError: If a key in filter_by is not a declared facet field
|
||||||
|
"""
|
||||||
|
index: dict[
|
||||||
|
str, tuple[InstrumentedAttribute[Any], list[InstrumentedAttribute[Any]]]
|
||||||
|
] = {}
|
||||||
|
for key, field in zip(facet_keys(facet_fields), facet_fields):
|
||||||
|
if isinstance(field, tuple):
|
||||||
|
rels = list(field[:-1])
|
||||||
|
column = field[-1]
|
||||||
|
else:
|
||||||
|
rels = []
|
||||||
|
column = field
|
||||||
|
index[key] = (column, rels)
|
||||||
|
|
||||||
|
valid_keys = set(index)
|
||||||
|
filters: list[ColumnElement[bool]] = []
|
||||||
|
joins: list[InstrumentedAttribute[Any]] = []
|
||||||
|
added_join_keys: set[str] = set()
|
||||||
|
|
||||||
|
for key, value in filter_by.items():
|
||||||
|
if key not in index:
|
||||||
|
raise InvalidFacetFilterError(key, valid_keys)
|
||||||
|
|
||||||
|
column, rels = index[key]
|
||||||
|
|
||||||
|
for rel in rels:
|
||||||
|
rel_key = str(rel)
|
||||||
|
if rel_key not in added_join_keys:
|
||||||
|
joins.append(rel)
|
||||||
|
added_join_keys.add(rel_key)
|
||||||
|
|
||||||
|
col_type = column.property.columns[0].type
|
||||||
|
if isinstance(col_type, ARRAY):
|
||||||
|
if isinstance(value, list):
|
||||||
|
filters.append(column.overlap(value))
|
||||||
|
else:
|
||||||
|
filters.append(column.any(value))
|
||||||
|
elif isinstance(col_type, Boolean):
|
||||||
|
if isinstance(value, list):
|
||||||
|
filters.append(column.in_(value))
|
||||||
|
else:
|
||||||
|
filters.append(column.is_(value))
|
||||||
|
elif isinstance(col_type, _EQUALITY_TYPES):
|
||||||
|
if isinstance(value, list):
|
||||||
|
filters.append(column.in_(value))
|
||||||
|
else:
|
||||||
|
filters.append(column == value)
|
||||||
|
else:
|
||||||
|
raise UnsupportedFacetTypeError(key, type(col_type).__name__)
|
||||||
|
|
||||||
|
return filters, joins
|
||||||
@@ -1,25 +1,35 @@
|
|||||||
"""Database utilities: sessions, transactions, and locks."""
|
"""Database utilities: sessions, transactions, and locks."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from collections.abc import AsyncGenerator, Callable
|
from collections.abc import AsyncGenerator, Callable
|
||||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from .exceptions import NotFoundError
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"LockMode",
|
"LockMode",
|
||||||
|
"cleanup_tables",
|
||||||
|
"create_database",
|
||||||
"create_db_context",
|
"create_db_context",
|
||||||
"create_db_dependency",
|
"create_db_dependency",
|
||||||
"lock_tables",
|
|
||||||
"get_transaction",
|
"get_transaction",
|
||||||
|
"lock_tables",
|
||||||
|
"wait_for_row_change",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
_SessionT = TypeVar("_SessionT", bound=AsyncSession)
|
||||||
|
|
||||||
|
|
||||||
def create_db_dependency(
|
def create_db_dependency(
|
||||||
session_maker: async_sessionmaker[AsyncSession],
|
session_maker: async_sessionmaker[_SessionT],
|
||||||
) -> Callable[[], AsyncGenerator[AsyncSession, None]]:
|
) -> Callable[[], AsyncGenerator[_SessionT, None]]:
|
||||||
"""Create a FastAPI dependency for database sessions.
|
"""Create a FastAPI dependency for database sessions.
|
||||||
|
|
||||||
Creates a dependency function that yields a session and auto-commits
|
Creates a dependency function that yields a session and auto-commits
|
||||||
@@ -32,6 +42,7 @@ def create_db_dependency(
|
|||||||
An async generator function usable with FastAPI's Depends()
|
An async generator function usable with FastAPI's Depends()
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||||
from fastapi_toolsets.db import create_db_dependency
|
from fastapi_toolsets.db import create_db_dependency
|
||||||
@@ -43,10 +54,12 @@ def create_db_dependency(
|
|||||||
@app.get("/users")
|
@app.get("/users")
|
||||||
async def list_users(session: AsyncSession = Depends(get_db)):
|
async def list_users(session: AsyncSession = Depends(get_db)):
|
||||||
...
|
...
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
async def get_db() -> AsyncGenerator[_SessionT, None]:
|
||||||
async with session_maker() as session:
|
async with session_maker() as session:
|
||||||
|
await session.connection()
|
||||||
yield session
|
yield session
|
||||||
if session.in_transaction():
|
if session.in_transaction():
|
||||||
await session.commit()
|
await session.commit()
|
||||||
@@ -55,8 +68,8 @@ def create_db_dependency(
|
|||||||
|
|
||||||
|
|
||||||
def create_db_context(
|
def create_db_context(
|
||||||
session_maker: async_sessionmaker[AsyncSession],
|
session_maker: async_sessionmaker[_SessionT],
|
||||||
) -> Callable[[], AbstractAsyncContextManager[AsyncSession]]:
|
) -> Callable[[], AbstractAsyncContextManager[_SessionT]]:
|
||||||
"""Create a context manager for database sessions.
|
"""Create a context manager for database sessions.
|
||||||
|
|
||||||
Creates a context manager for use outside of FastAPI request handlers,
|
Creates a context manager for use outside of FastAPI request handlers,
|
||||||
@@ -69,6 +82,7 @@ def create_db_context(
|
|||||||
An async context manager function
|
An async context manager function
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||||
from fastapi_toolsets.db import create_db_context
|
from fastapi_toolsets.db import create_db_context
|
||||||
|
|
||||||
@@ -80,6 +94,7 @@ def create_db_context(
|
|||||||
async with get_db_context() as session:
|
async with get_db_context() as session:
|
||||||
user = await UserCrud.get(session, [User.id == 1])
|
user = await UserCrud.get(session, [User.id == 1])
|
||||||
...
|
...
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
get_db = create_db_dependency(session_maker)
|
get_db = create_db_dependency(session_maker)
|
||||||
return asynccontextmanager(get_db)
|
return asynccontextmanager(get_db)
|
||||||
@@ -101,9 +116,11 @@ async def get_transaction(
|
|||||||
The session within the transaction context
|
The session within the transaction context
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
async with get_transaction(session):
|
async with get_transaction(session):
|
||||||
session.add(model)
|
session.add(model)
|
||||||
# Auto-commits on exit, rolls back on exception
|
# Auto-commits on exit, rolls back on exception
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
if session.in_transaction():
|
if session.in_transaction():
|
||||||
async with session.begin_nested():
|
async with session.begin_nested():
|
||||||
@@ -155,6 +172,7 @@ async def lock_tables(
|
|||||||
SQLAlchemyError: If lock cannot be acquired within timeout
|
SQLAlchemyError: If lock cannot be acquired within timeout
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
from fastapi_toolsets.db import lock_tables, LockMode
|
from fastapi_toolsets.db import lock_tables, LockMode
|
||||||
|
|
||||||
async with lock_tables(session, [User, Account]):
|
async with lock_tables(session, [User, Account]):
|
||||||
@@ -166,6 +184,7 @@ async def lock_tables(
|
|||||||
async with lock_tables(session, [Order], mode=LockMode.EXCLUSIVE):
|
async with lock_tables(session, [Order], mode=LockMode.EXCLUSIVE):
|
||||||
# Exclusive lock - no other transactions can access
|
# Exclusive lock - no other transactions can access
|
||||||
await process_order(session, order_id)
|
await process_order(session, order_id)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
table_names = ",".join(table.__tablename__ for table in tables)
|
table_names = ",".join(table.__tablename__ for table in tables)
|
||||||
|
|
||||||
@@ -173,3 +192,150 @@ async def lock_tables(
|
|||||||
await session.execute(text(f"SET LOCAL lock_timeout='{timeout}'"))
|
await session.execute(text(f"SET LOCAL lock_timeout='{timeout}'"))
|
||||||
await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE"))
|
await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE"))
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
async def create_database(
|
||||||
|
db_name: str,
|
||||||
|
*,
|
||||||
|
server_url: str,
|
||||||
|
) -> None:
|
||||||
|
"""Create a database.
|
||||||
|
|
||||||
|
Connects to *server_url* using ``AUTOCOMMIT`` isolation and issues a
|
||||||
|
``CREATE DATABASE`` statement for *db_name*.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_name: Name of the database to create.
|
||||||
|
server_url: URL used for server-level DDL (must point to an existing
|
||||||
|
database on the same server).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import create_database
|
||||||
|
|
||||||
|
SERVER_URL = "postgresql+asyncpg://postgres:postgres@localhost/postgres"
|
||||||
|
await create_database("myapp_test", server_url=SERVER_URL)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
engine = create_async_engine(server_url, isolation_level="AUTOCOMMIT")
|
||||||
|
try:
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
await conn.execute(text(f"CREATE DATABASE {db_name}"))
|
||||||
|
finally:
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
async def cleanup_tables(
|
||||||
|
session: AsyncSession,
|
||||||
|
base: type[DeclarativeBase],
|
||||||
|
) -> None:
|
||||||
|
"""Truncate all tables for fast between-test cleanup.
|
||||||
|
|
||||||
|
Executes a single ``TRUNCATE … RESTART IDENTITY CASCADE`` statement
|
||||||
|
across every table in *base*'s metadata, which is significantly faster
|
||||||
|
than dropping and re-creating tables between tests.
|
||||||
|
|
||||||
|
This is a no-op when the metadata contains no tables.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: An active async database session.
|
||||||
|
base: SQLAlchemy DeclarativeBase class containing model metadata.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session(worker_db_url):
|
||||||
|
async with create_db_session(worker_db_url, Base) as session:
|
||||||
|
yield session
|
||||||
|
await cleanup_tables(session, Base)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
tables = base.metadata.sorted_tables
|
||||||
|
if not tables:
|
||||||
|
return
|
||||||
|
|
||||||
|
table_names = ", ".join(f'"{t.name}"' for t in tables)
|
||||||
|
await session.execute(text(f"TRUNCATE {table_names} RESTART IDENTITY CASCADE"))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
_M = TypeVar("_M", bound=DeclarativeBase)
|
||||||
|
|
||||||
|
|
||||||
|
async def wait_for_row_change(
|
||||||
|
session: AsyncSession,
|
||||||
|
model: type[_M],
|
||||||
|
pk_value: Any,
|
||||||
|
*,
|
||||||
|
columns: list[str] | None = None,
|
||||||
|
interval: float = 0.5,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> _M:
|
||||||
|
"""Poll a database row until a change is detected.
|
||||||
|
|
||||||
|
Queries the row every ``interval`` seconds and returns the model instance
|
||||||
|
once a change is detected in any column (or only the specified ``columns``).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: AsyncSession instance
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
pk_value: Primary key value of the row to watch
|
||||||
|
columns: Optional list of column names to watch. If None, all columns
|
||||||
|
are watched.
|
||||||
|
interval: Polling interval in seconds (default: 0.5)
|
||||||
|
timeout: Maximum time to wait in seconds. None means wait forever.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The refreshed model instance with updated values
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If the row does not exist or is deleted during polling
|
||||||
|
TimeoutError: If timeout expires before a change is detected
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import wait_for_row_change
|
||||||
|
|
||||||
|
# Wait for any column to change
|
||||||
|
updated = await wait_for_row_change(session, User, user_id)
|
||||||
|
|
||||||
|
# Watch specific columns with a timeout
|
||||||
|
updated = await wait_for_row_change(
|
||||||
|
session, User, user_id,
|
||||||
|
columns=["status", "email"],
|
||||||
|
interval=1.0,
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
instance = await session.get(model, pk_value)
|
||||||
|
if instance is None:
|
||||||
|
raise NotFoundError(f"{model.__name__} with pk={pk_value!r} not found")
|
||||||
|
|
||||||
|
if columns is not None:
|
||||||
|
watch_cols = columns
|
||||||
|
else:
|
||||||
|
watch_cols = [attr.key for attr in model.__mapper__.column_attrs]
|
||||||
|
|
||||||
|
initial = {col: getattr(instance, col) for col in watch_cols}
|
||||||
|
|
||||||
|
elapsed = 0.0
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
elapsed += interval
|
||||||
|
|
||||||
|
if timeout is not None and elapsed >= timeout:
|
||||||
|
raise TimeoutError(
|
||||||
|
f"No change detected on {model.__name__} "
|
||||||
|
f"with pk={pk_value!r} within {timeout}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
session.expunge(instance)
|
||||||
|
instance = await session.get(model, pk_value)
|
||||||
|
|
||||||
|
if instance is None:
|
||||||
|
raise NotFoundError(f"{model.__name__} with pk={pk_value!r} was deleted")
|
||||||
|
|
||||||
|
current = {col: getattr(instance, col) for col in watch_cols}
|
||||||
|
if current != initial:
|
||||||
|
return instance
|
||||||
|
|||||||
155
src/fastapi_toolsets/dependencies.py
Normal file
155
src/fastapi_toolsets/dependencies.py
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
"""Dependency factories for FastAPI routes."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import typing
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
from fastapi.params import Depends as DependsClass
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from .crud import CrudFactory
|
||||||
|
from .types import ModelType, SessionDependency
|
||||||
|
|
||||||
|
__all__ = ["BodyDependency", "PathDependency"]
|
||||||
|
|
||||||
|
|
||||||
|
def _unwrap_session_dep(session_dep: SessionDependency) -> Callable[..., Any]:
|
||||||
|
"""Extract the plain callable from ``Annotated[AsyncSession, Depends(fn)]`` if needed."""
|
||||||
|
if typing.get_origin(session_dep) is typing.Annotated:
|
||||||
|
for arg in typing.get_args(session_dep)[1:]:
|
||||||
|
if isinstance(arg, DependsClass):
|
||||||
|
return arg.dependency
|
||||||
|
return session_dep
|
||||||
|
|
||||||
|
|
||||||
|
def PathDependency(
|
||||||
|
model: type[ModelType],
|
||||||
|
field: Any,
|
||||||
|
*,
|
||||||
|
session_dep: SessionDependency,
|
||||||
|
param_name: str | None = None,
|
||||||
|
) -> ModelType:
|
||||||
|
"""Create a dependency that fetches a DB object from a path parameter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
field: Model field to filter by (e.g., User.id)
|
||||||
|
session_dep: Session dependency function (e.g., get_db)
|
||||||
|
param_name: Path parameter name (defaults to model_field, e.g., user_id)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Depends() instance that resolves to the model instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If no matching record is found
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
UserDep = PathDependency(User, User.id, session_dep=get_db)
|
||||||
|
|
||||||
|
@router.get("/user/{id}")
|
||||||
|
async def get(
|
||||||
|
user: User = UserDep,
|
||||||
|
): ...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
session_callable = _unwrap_session_dep(session_dep)
|
||||||
|
crud = CrudFactory(model)
|
||||||
|
name = (
|
||||||
|
param_name
|
||||||
|
if param_name is not None
|
||||||
|
else "{}_{}".format(model.__name__.lower(), field.key)
|
||||||
|
)
|
||||||
|
python_type = field.type.python_type
|
||||||
|
|
||||||
|
async def dependency(
|
||||||
|
session: AsyncSession = Depends(session_callable), **kwargs: Any
|
||||||
|
) -> ModelType:
|
||||||
|
value = kwargs[name]
|
||||||
|
return await crud.get(session, filters=[field == value])
|
||||||
|
|
||||||
|
setattr(
|
||||||
|
dependency,
|
||||||
|
"__signature__",
|
||||||
|
inspect.Signature(
|
||||||
|
parameters=[
|
||||||
|
inspect.Parameter(
|
||||||
|
name, inspect.Parameter.KEYWORD_ONLY, annotation=python_type
|
||||||
|
),
|
||||||
|
inspect.Parameter(
|
||||||
|
"session",
|
||||||
|
inspect.Parameter.KEYWORD_ONLY,
|
||||||
|
annotation=AsyncSession,
|
||||||
|
default=Depends(session_callable),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return cast(ModelType, Depends(cast(Callable[..., ModelType], dependency)))
|
||||||
|
|
||||||
|
|
||||||
|
def BodyDependency(
|
||||||
|
model: type[ModelType],
|
||||||
|
field: Any,
|
||||||
|
*,
|
||||||
|
session_dep: SessionDependency,
|
||||||
|
body_field: str,
|
||||||
|
) -> ModelType:
|
||||||
|
"""Create a dependency that fetches a DB object from a body field.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
field: Model field to filter by (e.g., User.id)
|
||||||
|
session_dep: Session dependency function (e.g., get_db)
|
||||||
|
body_field: Name of the field in the request body
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Depends() instance that resolves to the model instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If no matching record is found
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
UserDep = BodyDependency(
|
||||||
|
User, User.ctfd_id, session_dep=get_db, body_field="user_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.post("/assign")
|
||||||
|
async def assign(
|
||||||
|
user: User = UserDep,
|
||||||
|
): ...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
session_callable = _unwrap_session_dep(session_dep)
|
||||||
|
crud = CrudFactory(model)
|
||||||
|
python_type = field.type.python_type
|
||||||
|
|
||||||
|
async def dependency(
|
||||||
|
session: AsyncSession = Depends(session_callable), **kwargs: Any
|
||||||
|
) -> ModelType:
|
||||||
|
value = kwargs[body_field]
|
||||||
|
return await crud.get(session, filters=[field == value])
|
||||||
|
|
||||||
|
setattr(
|
||||||
|
dependency,
|
||||||
|
"__signature__",
|
||||||
|
inspect.Signature(
|
||||||
|
parameters=[
|
||||||
|
inspect.Parameter(
|
||||||
|
body_field, inspect.Parameter.KEYWORD_ONLY, annotation=python_type
|
||||||
|
),
|
||||||
|
inspect.Parameter(
|
||||||
|
"session",
|
||||||
|
inspect.Parameter.KEYWORD_ONLY,
|
||||||
|
annotation=AsyncSession,
|
||||||
|
default=Depends(session_callable),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return cast(ModelType, Depends(cast(Callable[..., ModelType], dependency)))
|
||||||
@@ -1,19 +1,33 @@
|
|||||||
|
"""Standardized API exceptions and error response handlers."""
|
||||||
|
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
|
ApiError,
|
||||||
ApiException,
|
ApiException,
|
||||||
ConflictError,
|
ConflictError,
|
||||||
ForbiddenError,
|
ForbiddenError,
|
||||||
|
InvalidFacetFilterError,
|
||||||
|
InvalidOrderFieldError,
|
||||||
|
InvalidSearchColumnError,
|
||||||
|
NoSearchableFieldsError,
|
||||||
NotFoundError,
|
NotFoundError,
|
||||||
UnauthorizedError,
|
UnauthorizedError,
|
||||||
|
UnsupportedFacetTypeError,
|
||||||
generate_error_responses,
|
generate_error_responses,
|
||||||
)
|
)
|
||||||
from .handler import init_exceptions_handlers
|
from .handler import init_exceptions_handlers
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"init_exceptions_handlers",
|
"ApiError",
|
||||||
"generate_error_responses",
|
|
||||||
"ApiException",
|
"ApiException",
|
||||||
"ConflictError",
|
"ConflictError",
|
||||||
"ForbiddenError",
|
"ForbiddenError",
|
||||||
|
"generate_error_responses",
|
||||||
|
"init_exceptions_handlers",
|
||||||
|
"InvalidFacetFilterError",
|
||||||
|
"InvalidOrderFieldError",
|
||||||
|
"InvalidSearchColumnError",
|
||||||
|
"NoSearchableFieldsError",
|
||||||
"NotFoundError",
|
"NotFoundError",
|
||||||
"UnauthorizedError",
|
"UnauthorizedError",
|
||||||
|
"UnsupportedFacetTypeError",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -6,30 +6,46 @@ from ..schemas import ApiError, ErrorResponse, ResponseStatus
|
|||||||
|
|
||||||
|
|
||||||
class ApiException(Exception):
|
class ApiException(Exception):
|
||||||
"""Base exception for API errors with structured response.
|
"""Base exception for API errors with structured response."""
|
||||||
|
|
||||||
Subclass this to create custom API exceptions with consistent error format.
|
|
||||||
The exception handler will use api_error to generate the response.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
class CustomError(ApiException):
|
|
||||||
api_error = ApiError(
|
|
||||||
code=400,
|
|
||||||
msg="Bad Request",
|
|
||||||
desc="The request was invalid.",
|
|
||||||
err_code="CUSTOM-400",
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
api_error: ClassVar[ApiError]
|
api_error: ClassVar[ApiError]
|
||||||
|
|
||||||
def __init__(self, detail: str | None = None):
|
def __init_subclass__(cls, abstract: bool = False, **kwargs: Any) -> None:
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
if not abstract and not hasattr(cls, "api_error"):
|
||||||
|
raise TypeError(
|
||||||
|
f"{cls.__name__} must define an 'api_error' class attribute. "
|
||||||
|
"Pass abstract=True when creating intermediate base classes."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
detail: str | None = None,
|
||||||
|
*,
|
||||||
|
desc: str | None = None,
|
||||||
|
data: Any = None,
|
||||||
|
) -> None:
|
||||||
"""Initialize the exception.
|
"""Initialize the exception.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
detail: Optional override for the error message
|
detail: Optional human-readable message
|
||||||
|
desc: Optional per-instance override for the ``description`` field
|
||||||
|
in the HTTP response body.
|
||||||
|
data: Optional per-instance override for the ``data`` field in the
|
||||||
|
HTTP response body.
|
||||||
"""
|
"""
|
||||||
super().__init__(detail or self.api_error.msg)
|
updates: dict[str, Any] = {}
|
||||||
|
if detail is not None:
|
||||||
|
updates["msg"] = detail
|
||||||
|
if desc is not None:
|
||||||
|
updates["desc"] = desc
|
||||||
|
if data is not None:
|
||||||
|
updates["data"] = data
|
||||||
|
if updates:
|
||||||
|
object.__setattr__(
|
||||||
|
self, "api_error", self.__class__.api_error.model_copy(update=updates)
|
||||||
|
)
|
||||||
|
super().__init__(self.api_error.msg)
|
||||||
|
|
||||||
|
|
||||||
class UnauthorizedError(ApiException):
|
class UnauthorizedError(ApiException):
|
||||||
@@ -76,46 +92,134 @@ class ConflictError(ApiException):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InsufficientRolesError(ForbiddenError):
|
class NoSearchableFieldsError(ApiException):
|
||||||
"""User does not have the required roles."""
|
"""Raised when search is requested but no searchable fields are available."""
|
||||||
|
|
||||||
api_error = ApiError(
|
api_error = ApiError(
|
||||||
code=403,
|
code=400,
|
||||||
msg="Insufficient Roles",
|
msg="No Searchable Fields",
|
||||||
desc="You do not have the required roles to access this resource.",
|
desc="No searchable fields configured for this resource.",
|
||||||
err_code="RBAC-403",
|
err_code="SEARCH-400",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, required_roles: list[str], user_roles: set[str] | None = None):
|
def __init__(self, model: type) -> None:
|
||||||
self.required_roles = required_roles
|
"""Initialize the exception.
|
||||||
self.user_roles = user_roles
|
|
||||||
|
|
||||||
desc = f"Required roles: {', '.join(required_roles)}"
|
Args:
|
||||||
if user_roles is not None:
|
model: The model class that has no searchable fields configured.
|
||||||
desc += f". User has: {', '.join(user_roles) if user_roles else 'no roles'}"
|
"""
|
||||||
|
self.model = model
|
||||||
super().__init__(desc)
|
super().__init__(
|
||||||
|
desc=(
|
||||||
|
f"No searchable fields found for model '{model.__name__}'. "
|
||||||
class UserNotFoundError(NotFoundError):
|
"Provide 'search_fields' parameter or set 'searchable_fields' on the CRUD class."
|
||||||
"""User was not found."""
|
)
|
||||||
|
|
||||||
api_error = ApiError(
|
|
||||||
code=404,
|
|
||||||
msg="User Not Found",
|
|
||||||
desc="The requested user was not found.",
|
|
||||||
err_code="USER-404",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class RoleNotFoundError(NotFoundError):
|
class InvalidFacetFilterError(ApiException):
|
||||||
"""Role was not found."""
|
"""Raised when filter_by contains a key not declared in facet_fields."""
|
||||||
|
|
||||||
api_error = ApiError(
|
api_error = ApiError(
|
||||||
code=404,
|
code=400,
|
||||||
msg="Role Not Found",
|
msg="Invalid Facet Filter",
|
||||||
desc="The requested role was not found.",
|
desc="One or more filter_by keys are not declared as facet fields.",
|
||||||
err_code="ROLE-404",
|
err_code="FACET-400",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, key: str, valid_keys: set[str]) -> None:
|
||||||
|
"""Initialize the exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The unknown filter key provided by the caller.
|
||||||
|
valid_keys: Set of valid keys derived from the declared facet_fields.
|
||||||
|
"""
|
||||||
|
self.key = key
|
||||||
|
self.valid_keys = valid_keys
|
||||||
|
super().__init__(
|
||||||
|
desc=(
|
||||||
|
f"'{key}' is not a declared facet field. "
|
||||||
|
f"Valid keys: {sorted(valid_keys) or 'none — set facet_fields on the CRUD class'}."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedFacetTypeError(ApiException):
|
||||||
|
"""Raised when a facet field has a column type not supported by filter_by."""
|
||||||
|
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400,
|
||||||
|
msg="Unsupported Facet Type",
|
||||||
|
desc="The column type is not supported for facet filtering.",
|
||||||
|
err_code="FACET-TYPE-400",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, key: str, col_type: str) -> None:
|
||||||
|
"""Initialize the exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The facet field key.
|
||||||
|
col_type: The unsupported column type name.
|
||||||
|
"""
|
||||||
|
self.key = key
|
||||||
|
self.col_type = col_type
|
||||||
|
super().__init__(
|
||||||
|
desc=(
|
||||||
|
f"Facet field '{key}' has unsupported column type '{col_type}'. "
|
||||||
|
f"Supported types: String, Integer, Numeric, Boolean, "
|
||||||
|
f"Date, DateTime, Time, Enum, Uuid, ARRAY."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidSearchColumnError(ApiException):
|
||||||
|
"""Raised when search_column is not one of the configured searchable fields."""
|
||||||
|
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400,
|
||||||
|
msg="Invalid Search Column",
|
||||||
|
desc="The requested search column is not a configured searchable field.",
|
||||||
|
err_code="SEARCH-COL-400",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, column: str, valid_columns: list[str]) -> None:
|
||||||
|
"""Initialize the exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
column: The unknown search column provided by the caller.
|
||||||
|
valid_columns: List of valid search column keys.
|
||||||
|
"""
|
||||||
|
self.column = column
|
||||||
|
self.valid_columns = valid_columns
|
||||||
|
super().__init__(
|
||||||
|
desc=(
|
||||||
|
f"'{column}' is not a searchable column. "
|
||||||
|
f"Valid columns: {valid_columns}."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidOrderFieldError(ApiException):
|
||||||
|
"""Raised when order_by contains a field not in the allowed order fields."""
|
||||||
|
|
||||||
|
api_error = ApiError(
|
||||||
|
code=422,
|
||||||
|
msg="Invalid Order Field",
|
||||||
|
desc="The requested order field is not allowed for this resource.",
|
||||||
|
err_code="SORT-422",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, field: str, valid_fields: list[str]) -> None:
|
||||||
|
"""Initialize the exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field: The unknown order field provided by the caller.
|
||||||
|
valid_fields: List of valid field names.
|
||||||
|
"""
|
||||||
|
self.field = field
|
||||||
|
self.valid_fields = valid_fields
|
||||||
|
super().__init__(
|
||||||
|
desc=f"'{field}' is not an allowed order field. Valid fields: {valid_fields}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -124,43 +228,40 @@ def generate_error_responses(
|
|||||||
) -> dict[int | str, dict[str, Any]]:
|
) -> dict[int | str, dict[str, Any]]:
|
||||||
"""Generate OpenAPI response documentation for exceptions.
|
"""Generate OpenAPI response documentation for exceptions.
|
||||||
|
|
||||||
Use this to document possible error responses for an endpoint.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
*errors: Exception classes that inherit from ApiException
|
*errors: Exception classes that inherit from ApiException.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict suitable for FastAPI's responses parameter
|
Dict suitable for FastAPI's ``responses`` parameter.
|
||||||
|
|
||||||
Example:
|
|
||||||
from fastapi_toolsets.exceptions import generate_error_responses, UnauthorizedError, ForbiddenError
|
|
||||||
|
|
||||||
@app.get(
|
|
||||||
"/admin",
|
|
||||||
responses=generate_error_responses(UnauthorizedError, ForbiddenError)
|
|
||||||
)
|
|
||||||
async def admin_endpoint():
|
|
||||||
...
|
|
||||||
"""
|
"""
|
||||||
responses: dict[int | str, dict[str, Any]] = {}
|
responses: dict[int | str, dict[str, Any]] = {}
|
||||||
|
|
||||||
for error in errors:
|
for error in errors:
|
||||||
api_error = error.api_error
|
api_error = error.api_error
|
||||||
|
code = api_error.code
|
||||||
|
|
||||||
responses[api_error.code] = {
|
if code not in responses:
|
||||||
|
responses[code] = {
|
||||||
"model": ErrorResponse,
|
"model": ErrorResponse,
|
||||||
"description": api_error.msg,
|
"description": api_error.msg,
|
||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"example": {
|
"examples": {},
|
||||||
"data": None,
|
|
||||||
"status": ResponseStatus.FAIL.value,
|
|
||||||
"message": api_error.msg,
|
|
||||||
"description": api_error.desc,
|
|
||||||
"error_code": api_error.err_code,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
responses[code]["content"]["application/json"]["examples"][
|
||||||
|
api_error.err_code
|
||||||
|
] = {
|
||||||
|
"summary": api_error.msg,
|
||||||
|
"value": {
|
||||||
|
"data": api_error.data,
|
||||||
|
"status": ResponseStatus.FAIL.value,
|
||||||
|
"message": api_error.msg,
|
||||||
|
"description": api_error.desc,
|
||||||
|
"error_code": api_error.err_code,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
return responses
|
return responses
|
||||||
|
|||||||
@@ -1,50 +1,69 @@
|
|||||||
"""Exception handlers for FastAPI applications."""
|
"""Exception handlers for FastAPI applications."""
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import FastAPI, Request, Response, status
|
from fastapi import FastAPI, Request, Response, status
|
||||||
from fastapi.exceptions import RequestValidationError, ResponseValidationError
|
from fastapi.exceptions import (
|
||||||
from fastapi.openapi.utils import get_openapi
|
HTTPException,
|
||||||
|
RequestValidationError,
|
||||||
|
ResponseValidationError,
|
||||||
|
)
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from ..schemas import ResponseStatus
|
from ..schemas import ErrorResponse, ResponseStatus
|
||||||
from .exceptions import ApiException
|
from .exceptions import ApiException
|
||||||
|
|
||||||
|
_VALIDATION_LOCATION_PARAMS: frozenset[str] = frozenset(
|
||||||
|
{"body", "query", "path", "header", "cookie"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def init_exceptions_handlers(app: FastAPI) -> FastAPI:
|
def init_exceptions_handlers(app: FastAPI) -> FastAPI:
|
||||||
|
"""Register exception handlers and custom OpenAPI schema on a FastAPI app.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI application instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The same FastAPI instance (for chaining).
|
||||||
|
"""
|
||||||
_register_exception_handlers(app)
|
_register_exception_handlers(app)
|
||||||
app.openapi = lambda: _custom_openapi(app) # type: ignore[method-assign]
|
_original_openapi = app.openapi
|
||||||
|
app.openapi = lambda: _patched_openapi(app, _original_openapi) # type: ignore[method-assign] # ty:ignore[invalid-assignment]
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
def _register_exception_handlers(app: FastAPI) -> None:
|
def _register_exception_handlers(app: FastAPI) -> None:
|
||||||
"""Register all exception handlers on a FastAPI application.
|
"""Register all exception handlers on a FastAPI application."""
|
||||||
|
|
||||||
Args:
|
|
||||||
app: FastAPI application instance
|
|
||||||
|
|
||||||
Example:
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
init_exceptions_handlers(app)
|
|
||||||
"""
|
|
||||||
|
|
||||||
@app.exception_handler(ApiException)
|
@app.exception_handler(ApiException)
|
||||||
async def api_exception_handler(request: Request, exc: ApiException) -> Response:
|
async def api_exception_handler(request: Request, exc: ApiException) -> Response:
|
||||||
"""Handle custom API exceptions with structured response."""
|
"""Handle custom API exceptions with structured response."""
|
||||||
api_error = exc.api_error
|
api_error = exc.api_error
|
||||||
|
error_response = ErrorResponse(
|
||||||
|
data=api_error.data,
|
||||||
|
message=api_error.msg,
|
||||||
|
description=api_error.desc,
|
||||||
|
error_code=api_error.err_code,
|
||||||
|
)
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=api_error.code,
|
status_code=api_error.code,
|
||||||
content={
|
content=error_response.model_dump(),
|
||||||
"data": None,
|
)
|
||||||
"status": ResponseStatus.FAIL.value,
|
|
||||||
"message": api_error.msg,
|
@app.exception_handler(HTTPException)
|
||||||
"description": api_error.desc,
|
async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
|
||||||
"error_code": api_error.err_code,
|
"""Handle Starlette/FastAPI HTTPException with a consistent error format."""
|
||||||
},
|
detail = exc.detail if isinstance(exc.detail, str) else "HTTP Error"
|
||||||
|
error_response = ErrorResponse(
|
||||||
|
message=detail,
|
||||||
|
error_code=f"HTTP-{exc.status_code}",
|
||||||
|
)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content=error_response.model_dump(),
|
||||||
|
headers=getattr(exc, "headers", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.exception_handler(RequestValidationError)
|
@app.exception_handler(RequestValidationError)
|
||||||
@@ -64,15 +83,14 @@ def _register_exception_handlers(app: FastAPI) -> None:
|
|||||||
@app.exception_handler(Exception)
|
@app.exception_handler(Exception)
|
||||||
async def generic_exception_handler(request: Request, exc: Exception) -> Response:
|
async def generic_exception_handler(request: Request, exc: Exception) -> Response:
|
||||||
"""Handle all unhandled exceptions with a generic 500 response."""
|
"""Handle all unhandled exceptions with a generic 500 response."""
|
||||||
|
error_response = ErrorResponse(
|
||||||
|
message="Internal Server Error",
|
||||||
|
description="An unexpected error occurred. Please try again later.",
|
||||||
|
error_code="SERVER-500",
|
||||||
|
)
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
content={
|
content=error_response.model_dump(),
|
||||||
"data": None,
|
|
||||||
"status": ResponseStatus.FAIL.value,
|
|
||||||
"message": "Internal Server Error",
|
|
||||||
"description": "An unexpected error occurred. Please try again later.",
|
|
||||||
"error_code": "SERVER-500",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -84,11 +102,10 @@ def _format_validation_error(
|
|||||||
formatted_errors = []
|
formatted_errors = []
|
||||||
|
|
||||||
for error in errors:
|
for error in errors:
|
||||||
field_path = ".".join(
|
locs = error["loc"]
|
||||||
str(loc)
|
if locs and locs[0] in _VALIDATION_LOCATION_PARAMS:
|
||||||
for loc in error["loc"]
|
locs = locs[1:]
|
||||||
if loc not in ("body", "query", "path", "header", "cookie")
|
field_path = ".".join(str(loc) for loc in locs)
|
||||||
)
|
|
||||||
formatted_errors.append(
|
formatted_errors.append(
|
||||||
{
|
{
|
||||||
"field": field_path or "root",
|
"field": field_path or "root",
|
||||||
@@ -97,46 +114,35 @@ def _format_validation_error(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
error_response = ErrorResponse(
|
||||||
|
data={"errors": formatted_errors},
|
||||||
|
message="Validation Error",
|
||||||
|
description=f"{len(formatted_errors)} validation error(s) detected",
|
||||||
|
error_code="VAL-422",
|
||||||
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
|
||||||
content={
|
content=error_response.model_dump(),
|
||||||
"data": {"errors": formatted_errors},
|
|
||||||
"status": ResponseStatus.FAIL.value,
|
|
||||||
"message": "Validation Error",
|
|
||||||
"description": f"{len(formatted_errors)} validation error(s) detected",
|
|
||||||
"error_code": "VAL-422",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _custom_openapi(app: FastAPI) -> dict[str, Any]:
|
def _patched_openapi(
|
||||||
"""Generate custom OpenAPI schema with standardized error format.
|
app: FastAPI, original_openapi: Callable[[], dict[str, Any]]
|
||||||
|
) -> dict[str, Any]:
|
||||||
Replaces default 422 validation error responses with the custom format.
|
"""Generate the OpenAPI schema and replace default 422 responses.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
app: FastAPI application instance
|
app: FastAPI application instance.
|
||||||
|
original_openapi: The previous ``app.openapi`` callable to delegate to.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
OpenAPI schema dict
|
Patched OpenAPI schema dict.
|
||||||
|
|
||||||
Example:
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
init_exceptions_handlers(app) # Automatically sets custom OpenAPI
|
|
||||||
"""
|
"""
|
||||||
if app.openapi_schema:
|
if app.openapi_schema:
|
||||||
return app.openapi_schema
|
return app.openapi_schema
|
||||||
|
|
||||||
openapi_schema = get_openapi(
|
openapi_schema = original_openapi()
|
||||||
title=app.title,
|
|
||||||
version=app.version,
|
|
||||||
openapi_version=app.openapi_version,
|
|
||||||
description=app.description,
|
|
||||||
routes=app.routes,
|
|
||||||
)
|
|
||||||
|
|
||||||
for path_data in openapi_schema.get("paths", {}).values():
|
for path_data in openapi_schema.get("paths", {}).values():
|
||||||
for operation in path_data.values():
|
for operation in path_data.values():
|
||||||
@@ -146,7 +152,10 @@ def _custom_openapi(app: FastAPI) -> dict[str, Any]:
|
|||||||
"description": "Validation Error",
|
"description": "Validation Error",
|
||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"example": {
|
"examples": {
|
||||||
|
"VAL-422": {
|
||||||
|
"summary": "Validation Error",
|
||||||
|
"value": {
|
||||||
"data": {
|
"data": {
|
||||||
"errors": [
|
"errors": [
|
||||||
{
|
{
|
||||||
@@ -160,6 +169,8 @@ def _custom_openapi(app: FastAPI) -> dict[str, Any]:
|
|||||||
"message": "Validation Error",
|
"message": "Validation Error",
|
||||||
"description": "1 validation error(s) detected",
|
"description": "1 validation error(s) detected",
|
||||||
"error_code": "VAL-422",
|
"error_code": "VAL-422",
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -1,11 +1,8 @@
|
|||||||
from .fixtures import (
|
"""Fixture system for seeding databases with dependency resolution."""
|
||||||
Context,
|
|
||||||
FixtureRegistry,
|
from .enum import LoadStrategy
|
||||||
LoadStrategy,
|
from .registry import Context, FixtureRegistry
|
||||||
load_fixtures,
|
from .utils import get_obj_by_attr, load_fixtures, load_fixtures_by_context
|
||||||
load_fixtures_by_context,
|
|
||||||
)
|
|
||||||
from .utils import get_obj_by_attr
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Context",
|
"Context",
|
||||||
@@ -16,12 +13,3 @@ __all__ = [
|
|||||||
"load_fixtures_by_context",
|
"load_fixtures_by_context",
|
||||||
"register_fixtures",
|
"register_fixtures",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# We lazy-load register_fixtures to avoid needing pytest when using fixtures CLI
|
|
||||||
def __getattr__(name: str):
|
|
||||||
if name == "register_fixtures":
|
|
||||||
from .pytest_plugin import register_fixtures
|
|
||||||
|
|
||||||
return register_fixtures
|
|
||||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
|
||||||
|
|||||||
32
src/fastapi_toolsets/fixtures/enum.py
Normal file
32
src/fastapi_toolsets/fixtures/enum.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""Enums for fixture loading strategies and contexts."""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class LoadStrategy(str, Enum):
|
||||||
|
"""Strategy for loading fixtures into the database."""
|
||||||
|
|
||||||
|
INSERT = "insert"
|
||||||
|
"""Insert new records. Fails if record already exists."""
|
||||||
|
|
||||||
|
MERGE = "merge"
|
||||||
|
"""Insert or update based on primary key (SQLAlchemy merge)."""
|
||||||
|
|
||||||
|
SKIP_EXISTING = "skip_existing"
|
||||||
|
"""Insert only if record doesn't exist (based on primary key)."""
|
||||||
|
|
||||||
|
|
||||||
|
class Context(str, Enum):
|
||||||
|
"""Predefined fixture contexts."""
|
||||||
|
|
||||||
|
BASE = "base"
|
||||||
|
"""Base fixtures loaded in all environments."""
|
||||||
|
|
||||||
|
PRODUCTION = "production"
|
||||||
|
"""Production-only fixtures."""
|
||||||
|
|
||||||
|
DEVELOPMENT = "development"
|
||||||
|
"""Development fixtures."""
|
||||||
|
|
||||||
|
TESTING = "testing"
|
||||||
|
"""Test fixtures."""
|
||||||
@@ -1,321 +0,0 @@
|
|||||||
"""Fixture system with dependency management and context support."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from collections.abc import Callable, Sequence
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, cast
|
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
|
||||||
|
|
||||||
from ..db import get_transaction
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class LoadStrategy(str, Enum):
|
|
||||||
"""Strategy for loading fixtures into the database."""
|
|
||||||
|
|
||||||
INSERT = "insert"
|
|
||||||
"""Insert new records. Fails if record already exists."""
|
|
||||||
|
|
||||||
MERGE = "merge"
|
|
||||||
"""Insert or update based on primary key (SQLAlchemy merge)."""
|
|
||||||
|
|
||||||
SKIP_EXISTING = "skip_existing"
|
|
||||||
"""Insert only if record doesn't exist (based on primary key)."""
|
|
||||||
|
|
||||||
|
|
||||||
class Context(str, Enum):
|
|
||||||
"""Predefined fixture contexts."""
|
|
||||||
|
|
||||||
BASE = "base"
|
|
||||||
"""Base fixtures loaded in all environments."""
|
|
||||||
|
|
||||||
PRODUCTION = "production"
|
|
||||||
"""Production-only fixtures."""
|
|
||||||
|
|
||||||
DEVELOPMENT = "development"
|
|
||||||
"""Development fixtures."""
|
|
||||||
|
|
||||||
TESTING = "testing"
|
|
||||||
"""Test fixtures."""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Fixture:
|
|
||||||
"""A fixture definition with metadata."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
func: Callable[[], Sequence[DeclarativeBase]]
|
|
||||||
depends_on: list[str] = field(default_factory=list)
|
|
||||||
contexts: list[str] = field(default_factory=lambda: [Context.BASE])
|
|
||||||
|
|
||||||
|
|
||||||
class FixtureRegistry:
|
|
||||||
"""Registry for managing fixtures with dependencies.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
from fastapi_toolsets.fixtures import FixtureRegistry, Context
|
|
||||||
|
|
||||||
fixtures = FixtureRegistry()
|
|
||||||
|
|
||||||
@fixtures.register
|
|
||||||
def roles():
|
|
||||||
return [
|
|
||||||
Role(id=1, name="admin"),
|
|
||||||
Role(id=2, name="user"),
|
|
||||||
]
|
|
||||||
|
|
||||||
@fixtures.register(depends_on=["roles"])
|
|
||||||
def users():
|
|
||||||
return [
|
|
||||||
User(id=1, username="admin", role_id=1),
|
|
||||||
]
|
|
||||||
|
|
||||||
@fixtures.register(depends_on=["users"], contexts=[Context.TESTING])
|
|
||||||
def test_data():
|
|
||||||
return [
|
|
||||||
Post(id=1, title="Test", user_id=1),
|
|
||||||
]
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._fixtures: dict[str, Fixture] = {}
|
|
||||||
|
|
||||||
def register(
|
|
||||||
self,
|
|
||||||
func: Callable[[], Sequence[DeclarativeBase]] | None = None,
|
|
||||||
*,
|
|
||||||
name: str | None = None,
|
|
||||||
depends_on: list[str] | None = None,
|
|
||||||
contexts: list[str | Context] | None = None,
|
|
||||||
) -> Callable[..., Any]:
|
|
||||||
"""Register a fixture function.
|
|
||||||
|
|
||||||
Can be used as a decorator with or without arguments.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
func: Fixture function returning list of model instances
|
|
||||||
name: Fixture name (defaults to function name)
|
|
||||||
depends_on: List of fixture names this depends on
|
|
||||||
contexts: List of contexts this fixture belongs to
|
|
||||||
|
|
||||||
Example:
|
|
||||||
@fixtures.register
|
|
||||||
def roles():
|
|
||||||
return [Role(id=1, name="admin")]
|
|
||||||
|
|
||||||
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
|
|
||||||
def test_users():
|
|
||||||
return [User(id=1, username="test", role_id=1)]
|
|
||||||
"""
|
|
||||||
|
|
||||||
def decorator(
|
|
||||||
fn: Callable[[], Sequence[DeclarativeBase]],
|
|
||||||
) -> Callable[[], Sequence[DeclarativeBase]]:
|
|
||||||
fixture_name = name or cast(Any, fn).__name__
|
|
||||||
fixture_contexts = [
|
|
||||||
c.value if isinstance(c, Context) else c
|
|
||||||
for c in (contexts or [Context.BASE])
|
|
||||||
]
|
|
||||||
|
|
||||||
self._fixtures[fixture_name] = Fixture(
|
|
||||||
name=fixture_name,
|
|
||||||
func=fn,
|
|
||||||
depends_on=depends_on or [],
|
|
||||||
contexts=fixture_contexts,
|
|
||||||
)
|
|
||||||
return fn
|
|
||||||
|
|
||||||
if func is not None:
|
|
||||||
return decorator(func)
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
def get(self, name: str) -> Fixture:
|
|
||||||
"""Get a fixture by name."""
|
|
||||||
if name not in self._fixtures:
|
|
||||||
raise KeyError(f"Fixture '{name}' not found")
|
|
||||||
return self._fixtures[name]
|
|
||||||
|
|
||||||
def get_all(self) -> list[Fixture]:
|
|
||||||
"""Get all registered fixtures."""
|
|
||||||
return list(self._fixtures.values())
|
|
||||||
|
|
||||||
def get_by_context(self, *contexts: str | Context) -> list[Fixture]:
|
|
||||||
"""Get fixtures for specific contexts."""
|
|
||||||
context_values = {c.value if isinstance(c, Context) else c for c in contexts}
|
|
||||||
return [f for f in self._fixtures.values() if set(f.contexts) & context_values]
|
|
||||||
|
|
||||||
def resolve_dependencies(self, *names: str) -> list[str]:
|
|
||||||
"""Resolve fixture dependencies in topological order.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
*names: Fixture names to resolve
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of fixture names in load order (dependencies first)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
KeyError: If a fixture is not found
|
|
||||||
ValueError: If circular dependency detected
|
|
||||||
"""
|
|
||||||
resolved: list[str] = []
|
|
||||||
seen: set[str] = set()
|
|
||||||
visiting: set[str] = set()
|
|
||||||
|
|
||||||
def visit(name: str) -> None:
|
|
||||||
if name in resolved:
|
|
||||||
return
|
|
||||||
if name in visiting:
|
|
||||||
raise ValueError(f"Circular dependency detected: {name}")
|
|
||||||
|
|
||||||
visiting.add(name)
|
|
||||||
fixture = self.get(name)
|
|
||||||
|
|
||||||
for dep in fixture.depends_on:
|
|
||||||
visit(dep)
|
|
||||||
|
|
||||||
visiting.remove(name)
|
|
||||||
resolved.append(name)
|
|
||||||
seen.add(name)
|
|
||||||
|
|
||||||
for name in names:
|
|
||||||
visit(name)
|
|
||||||
|
|
||||||
return resolved
|
|
||||||
|
|
||||||
def resolve_context_dependencies(self, *contexts: str | Context) -> list[str]:
|
|
||||||
"""Resolve all fixtures for contexts with dependencies.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
*contexts: Contexts to load
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of fixture names in load order
|
|
||||||
"""
|
|
||||||
context_fixtures = self.get_by_context(*contexts)
|
|
||||||
names = [f.name for f in context_fixtures]
|
|
||||||
|
|
||||||
all_deps: set[str] = set()
|
|
||||||
for name in names:
|
|
||||||
deps = self.resolve_dependencies(name)
|
|
||||||
all_deps.update(deps)
|
|
||||||
|
|
||||||
return self.resolve_dependencies(*all_deps)
|
|
||||||
|
|
||||||
|
|
||||||
async def load_fixtures(
|
|
||||||
session: AsyncSession,
|
|
||||||
registry: FixtureRegistry,
|
|
||||||
*names: str,
|
|
||||||
strategy: LoadStrategy = LoadStrategy.MERGE,
|
|
||||||
) -> dict[str, list[DeclarativeBase]]:
|
|
||||||
"""Load specific fixtures by name with dependencies.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: Database session
|
|
||||||
registry: Fixture registry
|
|
||||||
*names: Fixture names to load (dependencies auto-resolved)
|
|
||||||
strategy: How to handle existing records
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict mapping fixture names to loaded instances
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# Loads 'roles' first (dependency), then 'users'
|
|
||||||
result = await load_fixtures(session, fixtures, "users")
|
|
||||||
print(result["users"]) # [User(...), ...]
|
|
||||||
"""
|
|
||||||
ordered = registry.resolve_dependencies(*names)
|
|
||||||
return await _load_ordered(session, registry, ordered, strategy)
|
|
||||||
|
|
||||||
|
|
||||||
async def load_fixtures_by_context(
|
|
||||||
session: AsyncSession,
|
|
||||||
registry: FixtureRegistry,
|
|
||||||
*contexts: str | Context,
|
|
||||||
strategy: LoadStrategy = LoadStrategy.MERGE,
|
|
||||||
) -> dict[str, list[DeclarativeBase]]:
|
|
||||||
"""Load all fixtures for specific contexts.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: Database session
|
|
||||||
registry: Fixture registry
|
|
||||||
*contexts: Contexts to load (e.g., Context.BASE, Context.TESTING)
|
|
||||||
strategy: How to handle existing records
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict mapping fixture names to loaded instances
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# Load base + testing fixtures
|
|
||||||
await load_fixtures_by_context(
|
|
||||||
session, fixtures,
|
|
||||||
Context.BASE, Context.TESTING
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
ordered = registry.resolve_context_dependencies(*contexts)
|
|
||||||
return await _load_ordered(session, registry, ordered, strategy)
|
|
||||||
|
|
||||||
|
|
||||||
async def _load_ordered(
|
|
||||||
session: AsyncSession,
|
|
||||||
registry: FixtureRegistry,
|
|
||||||
ordered_names: list[str],
|
|
||||||
strategy: LoadStrategy,
|
|
||||||
) -> dict[str, list[DeclarativeBase]]:
|
|
||||||
"""Load fixtures in order."""
|
|
||||||
results: dict[str, list[DeclarativeBase]] = {}
|
|
||||||
|
|
||||||
for name in ordered_names:
|
|
||||||
fixture = registry.get(name)
|
|
||||||
instances = list(fixture.func())
|
|
||||||
|
|
||||||
if not instances:
|
|
||||||
results[name] = []
|
|
||||||
continue
|
|
||||||
|
|
||||||
model_name = type(instances[0]).__name__
|
|
||||||
loaded: list[DeclarativeBase] = []
|
|
||||||
|
|
||||||
async with get_transaction(session):
|
|
||||||
for instance in instances:
|
|
||||||
if strategy == LoadStrategy.INSERT:
|
|
||||||
session.add(instance)
|
|
||||||
loaded.append(instance)
|
|
||||||
|
|
||||||
elif strategy == LoadStrategy.MERGE:
|
|
||||||
merged = await session.merge(instance)
|
|
||||||
loaded.append(merged)
|
|
||||||
|
|
||||||
elif strategy == LoadStrategy.SKIP_EXISTING:
|
|
||||||
pk = _get_primary_key(instance)
|
|
||||||
if pk is not None:
|
|
||||||
existing = await session.get(type(instance), pk)
|
|
||||||
if existing is None:
|
|
||||||
session.add(instance)
|
|
||||||
loaded.append(instance)
|
|
||||||
else:
|
|
||||||
session.add(instance)
|
|
||||||
loaded.append(instance)
|
|
||||||
|
|
||||||
results[name] = loaded
|
|
||||||
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
|
||||||
"""Get the primary key value of a model instance."""
|
|
||||||
mapper = instance.__class__.__mapper__
|
|
||||||
pk_cols = mapper.primary_key
|
|
||||||
|
|
||||||
if len(pk_cols) == 1:
|
|
||||||
return getattr(instance, pk_cols[0].name, None)
|
|
||||||
|
|
||||||
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
|
|
||||||
if all(v is not None for v in pk_values):
|
|
||||||
return pk_values
|
|
||||||
return None
|
|
||||||
311
src/fastapi_toolsets/fixtures/registry.py
Normal file
311
src/fastapi_toolsets/fixtures/registry.py
Normal file
@@ -0,0 +1,311 @@
|
|||||||
|
"""Fixture system with dependency management and context support."""
|
||||||
|
|
||||||
|
from collections.abc import Callable, Sequence
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from ..logger import get_logger
|
||||||
|
from .enum import Context
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_contexts(
|
||||||
|
contexts: list[str | Enum] | tuple[str | Enum, ...],
|
||||||
|
) -> list[str]:
|
||||||
|
"""Convert a sequence of any Enum subclass and/or plain strings to a list of strings."""
|
||||||
|
return [c.value if isinstance(c, Enum) else c for c in contexts]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Fixture:
|
||||||
|
"""A fixture definition with metadata."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
func: Callable[[], Sequence[DeclarativeBase]]
|
||||||
|
depends_on: list[str] = field(default_factory=list)
|
||||||
|
contexts: list[str] = field(default_factory=lambda: [Context.BASE])
|
||||||
|
|
||||||
|
|
||||||
|
class FixtureRegistry:
|
||||||
|
"""Registry for managing fixtures with dependencies.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import FixtureRegistry, Context
|
||||||
|
|
||||||
|
fixtures = FixtureRegistry()
|
||||||
|
|
||||||
|
@fixtures.register
|
||||||
|
def roles():
|
||||||
|
return [
|
||||||
|
Role(id=1, name="admin"),
|
||||||
|
Role(id=2, name="user"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@fixtures.register(depends_on=["roles"])
|
||||||
|
def users():
|
||||||
|
return [
|
||||||
|
User(id=1, username="admin", role_id=1),
|
||||||
|
]
|
||||||
|
|
||||||
|
@fixtures.register(depends_on=["users"], contexts=[Context.TESTING])
|
||||||
|
def test_data():
|
||||||
|
return [
|
||||||
|
Post(id=1, title="Test", user_id=1),
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
Fixtures with the same name may be registered for **different** contexts.
|
||||||
|
When multiple contexts are loaded together, their instances are merged:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@fixtures.register(contexts=[Context.BASE])
|
||||||
|
def users():
|
||||||
|
return [User(id=1, username="admin")]
|
||||||
|
|
||||||
|
@fixtures.register(contexts=[Context.TESTING])
|
||||||
|
def users():
|
||||||
|
return [User(id=2, username="tester")]
|
||||||
|
# load_fixtures_by_context(..., Context.BASE, Context.TESTING)
|
||||||
|
# → loads both User(admin) and User(tester) under the "users" name
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
contexts: list[str | Enum] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._fixtures: dict[str, list[Fixture]] = {}
|
||||||
|
self._default_contexts: list[str] | None = (
|
||||||
|
_normalize_contexts(contexts) if contexts else None
|
||||||
|
)
|
||||||
|
|
||||||
|
def _validate_no_context_overlap(self, name: str, new_contexts: list[str]) -> None:
|
||||||
|
"""Raise ``ValueError`` if any existing variant for *name* overlaps."""
|
||||||
|
existing_variants = self._fixtures.get(name, [])
|
||||||
|
new_set = set(new_contexts)
|
||||||
|
for variant in existing_variants:
|
||||||
|
if set(variant.contexts) & new_set:
|
||||||
|
raise ValueError(
|
||||||
|
f"Fixture '{name}' already exists in the current registry "
|
||||||
|
f"with overlapping contexts. Use distinct context sets for "
|
||||||
|
f"each variant of the same fixture name."
|
||||||
|
)
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self,
|
||||||
|
func: Callable[[], Sequence[DeclarativeBase]] | None = None,
|
||||||
|
*,
|
||||||
|
name: str | None = None,
|
||||||
|
depends_on: list[str] | None = None,
|
||||||
|
contexts: list[str | Enum] | None = None,
|
||||||
|
) -> Callable[..., Any]:
|
||||||
|
"""Register a fixture function.
|
||||||
|
|
||||||
|
Can be used as a decorator with or without arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: Fixture function returning list of model instances
|
||||||
|
name: Fixture name (defaults to function name)
|
||||||
|
depends_on: List of fixture names this depends on
|
||||||
|
contexts: List of contexts this fixture belongs to. Both
|
||||||
|
:class:`Context` enum values and plain strings are accepted.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
@fixtures.register
|
||||||
|
def roles():
|
||||||
|
return [Role(id=1, name="admin")]
|
||||||
|
|
||||||
|
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
|
||||||
|
def test_users():
|
||||||
|
return [User(id=1, username="test", role_id=1)]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(
|
||||||
|
fn: Callable[[], Sequence[DeclarativeBase]],
|
||||||
|
) -> Callable[[], Sequence[DeclarativeBase]]:
|
||||||
|
fixture_name = name or cast(Any, fn).__name__
|
||||||
|
if contexts is not None:
|
||||||
|
fixture_contexts = _normalize_contexts(contexts)
|
||||||
|
elif self._default_contexts is not None:
|
||||||
|
fixture_contexts = self._default_contexts
|
||||||
|
else:
|
||||||
|
fixture_contexts = [Context.BASE.value]
|
||||||
|
|
||||||
|
self._validate_no_context_overlap(fixture_name, fixture_contexts)
|
||||||
|
self._fixtures.setdefault(fixture_name, []).append(
|
||||||
|
Fixture(
|
||||||
|
name=fixture_name,
|
||||||
|
func=fn,
|
||||||
|
depends_on=depends_on or [],
|
||||||
|
contexts=fixture_contexts,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return fn
|
||||||
|
|
||||||
|
if func is not None:
|
||||||
|
return decorator(func)
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def include_registry(self, registry: "FixtureRegistry") -> None:
|
||||||
|
"""Include another `FixtureRegistry` in the same current `FixtureRegistry`.
|
||||||
|
|
||||||
|
Fixtures with the same name are allowed as long as their context sets
|
||||||
|
do not overlap. Conflicting contexts raise :class:`ValueError`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registry: The `FixtureRegistry` to include
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If a fixture name already exists with overlapping contexts
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
dev_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@dev_registry.register
|
||||||
|
def dev_data():
|
||||||
|
return [...]
|
||||||
|
|
||||||
|
registry.include_registry(registry=dev_registry)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
for name, variants in registry._fixtures.items():
|
||||||
|
for fixture in variants:
|
||||||
|
self._validate_no_context_overlap(name, fixture.contexts)
|
||||||
|
self._fixtures.setdefault(name, []).append(fixture)
|
||||||
|
|
||||||
|
def get(self, name: str) -> Fixture:
|
||||||
|
"""Get a fixture by name.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If no fixture with *name* is registered.
|
||||||
|
ValueError: If the fixture has multiple context variants — use
|
||||||
|
:meth:`get_variants` in that case.
|
||||||
|
"""
|
||||||
|
if name not in self._fixtures:
|
||||||
|
raise KeyError(f"Fixture '{name}' not found")
|
||||||
|
variants = self._fixtures[name]
|
||||||
|
if len(variants) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Fixture '{name}' has {len(variants)} context variants. "
|
||||||
|
f"Use get_variants('{name}') to retrieve them."
|
||||||
|
)
|
||||||
|
return variants[0]
|
||||||
|
|
||||||
|
def get_variants(self, name: str, *contexts: str | Enum) -> list[Fixture]:
|
||||||
|
"""Return all registered variants for *name*, optionally filtered by context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Fixture name.
|
||||||
|
*contexts: If given, only return variants whose context set
|
||||||
|
intersects with these values. Both :class:`Context` enum
|
||||||
|
values and plain strings are accepted.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching :class:`Fixture` objects (may be empty when a
|
||||||
|
context filter is applied and nothing matches).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If no fixture with *name* is registered.
|
||||||
|
"""
|
||||||
|
if name not in self._fixtures:
|
||||||
|
raise KeyError(f"Fixture '{name}' not found")
|
||||||
|
variants = self._fixtures[name]
|
||||||
|
if not contexts:
|
||||||
|
return list(variants)
|
||||||
|
context_values = set(_normalize_contexts(contexts))
|
||||||
|
return [v for v in variants if set(v.contexts) & context_values]
|
||||||
|
|
||||||
|
def get_all(self) -> list[Fixture]:
|
||||||
|
"""Get all registered fixtures (all variants of all names)."""
|
||||||
|
return [f for variants in self._fixtures.values() for f in variants]
|
||||||
|
|
||||||
|
def get_by_context(self, *contexts: str | Enum) -> list[Fixture]:
|
||||||
|
"""Get fixtures for specific contexts."""
|
||||||
|
context_values = set(_normalize_contexts(contexts))
|
||||||
|
return [
|
||||||
|
f
|
||||||
|
for variants in self._fixtures.values()
|
||||||
|
for f in variants
|
||||||
|
if set(f.contexts) & context_values
|
||||||
|
]
|
||||||
|
|
||||||
|
def resolve_dependencies(self, *names: str) -> list[str]:
|
||||||
|
"""Resolve fixture dependencies in topological order.
|
||||||
|
|
||||||
|
When a fixture name has multiple context variants, the union of all
|
||||||
|
variants' ``depends_on`` lists is used.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*names: Fixture names to resolve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of fixture names in load order (dependencies first)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If a fixture is not found
|
||||||
|
ValueError: If circular dependency detected
|
||||||
|
"""
|
||||||
|
resolved: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
visiting: set[str] = set()
|
||||||
|
|
||||||
|
def visit(name: str) -> None:
|
||||||
|
if name in resolved:
|
||||||
|
return
|
||||||
|
if name in visiting:
|
||||||
|
raise ValueError(f"Circular dependency detected: {name}")
|
||||||
|
|
||||||
|
visiting.add(name)
|
||||||
|
variants = self._fixtures.get(name)
|
||||||
|
if variants is None:
|
||||||
|
raise KeyError(f"Fixture '{name}' not found")
|
||||||
|
|
||||||
|
# Union of depends_on across all variants, preserving first-seen order.
|
||||||
|
seen_deps: set[str] = set()
|
||||||
|
all_deps: list[str] = []
|
||||||
|
for variant in variants:
|
||||||
|
for dep in variant.depends_on:
|
||||||
|
if dep not in seen_deps:
|
||||||
|
all_deps.append(dep)
|
||||||
|
seen_deps.add(dep)
|
||||||
|
|
||||||
|
for dep in all_deps:
|
||||||
|
visit(dep)
|
||||||
|
|
||||||
|
visiting.remove(name)
|
||||||
|
resolved.append(name)
|
||||||
|
seen.add(name)
|
||||||
|
|
||||||
|
for name in names:
|
||||||
|
visit(name)
|
||||||
|
|
||||||
|
return resolved
|
||||||
|
|
||||||
|
def resolve_context_dependencies(self, *contexts: str | Enum) -> list[str]:
|
||||||
|
"""Resolve all fixtures for contexts with dependencies.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*contexts: Contexts to load
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of fixture names in load order
|
||||||
|
"""
|
||||||
|
context_fixtures = self.get_by_context(*contexts)
|
||||||
|
# Deduplicate names while preserving first-seen order (a name can
|
||||||
|
# appear multiple times if it has variants in different contexts).
|
||||||
|
names = list(dict.fromkeys(f.name for f in context_fixtures))
|
||||||
|
|
||||||
|
all_deps: set[str] = set()
|
||||||
|
for name in names:
|
||||||
|
deps = self.resolve_dependencies(name)
|
||||||
|
all_deps.update(deps)
|
||||||
|
|
||||||
|
return self.resolve_dependencies(*all_deps)
|
||||||
@@ -1,14 +1,233 @@
|
|||||||
from collections.abc import Callable, Sequence
|
"""Fixture loading utilities for database seeding."""
|
||||||
from typing import Any, TypeVar
|
|
||||||
|
|
||||||
|
from collections.abc import Callable, Sequence
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import inspect as sa_inspect
|
||||||
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
T = TypeVar("T", bound=DeclarativeBase)
|
from ..db import get_transaction
|
||||||
|
from ..logger import get_logger
|
||||||
|
from ..types import ModelType
|
||||||
|
from .enum import LoadStrategy
|
||||||
|
from .registry import FixtureRegistry, _normalize_contexts
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _instance_to_dict(instance: DeclarativeBase) -> dict[str, Any]:
|
||||||
|
"""Extract column values from a model instance, skipping unset server-default columns."""
|
||||||
|
state = sa_inspect(instance)
|
||||||
|
state_dict = state.dict
|
||||||
|
result: dict[str, Any] = {}
|
||||||
|
for prop in state.mapper.column_attrs:
|
||||||
|
if prop.key not in state_dict:
|
||||||
|
continue
|
||||||
|
val = state_dict[prop.key]
|
||||||
|
if val is None:
|
||||||
|
col = prop.columns[0]
|
||||||
|
|
||||||
|
if (
|
||||||
|
col.server_default is not None
|
||||||
|
or (col.default is not None and col.default.is_callable)
|
||||||
|
or col.autoincrement is True
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
result[prop.key] = val
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _group_by_type(
|
||||||
|
instances: list[DeclarativeBase],
|
||||||
|
) -> list[tuple[type[DeclarativeBase], list[DeclarativeBase]]]:
|
||||||
|
"""Group instances by their concrete model class, preserving insertion order."""
|
||||||
|
groups: dict[type[DeclarativeBase], list[DeclarativeBase]] = {}
|
||||||
|
for instance in instances:
|
||||||
|
groups.setdefault(type(instance), []).append(instance)
|
||||||
|
return list(groups.items())
|
||||||
|
|
||||||
|
|
||||||
|
def _group_by_column_set(
|
||||||
|
dicts: list[dict[str, Any]],
|
||||||
|
instances: list[DeclarativeBase],
|
||||||
|
) -> list[tuple[list[dict[str, Any]], list[DeclarativeBase]]]:
|
||||||
|
"""Group (dict, instance) pairs by their dict key sets."""
|
||||||
|
groups: dict[
|
||||||
|
frozenset[str], tuple[list[dict[str, Any]], list[DeclarativeBase]]
|
||||||
|
] = {}
|
||||||
|
for d, inst in zip(dicts, instances):
|
||||||
|
key = frozenset(d)
|
||||||
|
if key not in groups:
|
||||||
|
groups[key] = ([], [])
|
||||||
|
groups[key][0].append(d)
|
||||||
|
groups[key][1].append(inst)
|
||||||
|
return list(groups.values())
|
||||||
|
|
||||||
|
|
||||||
|
async def _batch_insert(
|
||||||
|
session: AsyncSession,
|
||||||
|
model_cls: type[DeclarativeBase],
|
||||||
|
instances: list[DeclarativeBase],
|
||||||
|
) -> None:
|
||||||
|
"""INSERT all instances — raises on conflict (no duplicate handling)."""
|
||||||
|
dicts = [_instance_to_dict(i) for i in instances]
|
||||||
|
for group_dicts, _ in _group_by_column_set(dicts, instances):
|
||||||
|
await session.execute(pg_insert(model_cls).values(group_dicts))
|
||||||
|
|
||||||
|
|
||||||
|
async def _batch_merge(
|
||||||
|
session: AsyncSession,
|
||||||
|
model_cls: type[DeclarativeBase],
|
||||||
|
instances: list[DeclarativeBase],
|
||||||
|
) -> None:
|
||||||
|
"""UPSERT: insert new rows, update existing ones with the provided values."""
|
||||||
|
mapper = model_cls.__mapper__
|
||||||
|
pk_names = [col.name for col in mapper.primary_key]
|
||||||
|
pk_names_set = set(pk_names)
|
||||||
|
non_pk_cols = [
|
||||||
|
prop.key
|
||||||
|
for prop in mapper.column_attrs
|
||||||
|
if not any(col.name in pk_names_set for col in prop.columns)
|
||||||
|
]
|
||||||
|
|
||||||
|
dicts = [_instance_to_dict(i) for i in instances]
|
||||||
|
for group_dicts, _ in _group_by_column_set(dicts, instances):
|
||||||
|
stmt = pg_insert(model_cls).values(group_dicts)
|
||||||
|
|
||||||
|
inserted_keys = set(group_dicts[0])
|
||||||
|
update_cols = [col for col in non_pk_cols if col in inserted_keys]
|
||||||
|
|
||||||
|
if update_cols:
|
||||||
|
stmt = stmt.on_conflict_do_update(
|
||||||
|
index_elements=pk_names,
|
||||||
|
set_={col: stmt.excluded[col] for col in update_cols},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
stmt = stmt.on_conflict_do_nothing(index_elements=pk_names)
|
||||||
|
|
||||||
|
await session.execute(stmt)
|
||||||
|
|
||||||
|
|
||||||
|
async def _batch_skip_existing(
|
||||||
|
session: AsyncSession,
|
||||||
|
model_cls: type[DeclarativeBase],
|
||||||
|
instances: list[DeclarativeBase],
|
||||||
|
) -> list[DeclarativeBase]:
|
||||||
|
"""INSERT only rows that do not already exist; return the inserted ones."""
|
||||||
|
mapper = model_cls.__mapper__
|
||||||
|
pk_names = [col.name for col in mapper.primary_key]
|
||||||
|
|
||||||
|
no_pk: list[DeclarativeBase] = []
|
||||||
|
with_pk_pairs: list[tuple[DeclarativeBase, Any]] = []
|
||||||
|
for inst in instances:
|
||||||
|
pk = _get_primary_key(inst)
|
||||||
|
if pk is None:
|
||||||
|
no_pk.append(inst)
|
||||||
|
else:
|
||||||
|
with_pk_pairs.append((inst, pk))
|
||||||
|
|
||||||
|
loaded: list[DeclarativeBase] = list(no_pk)
|
||||||
|
if no_pk:
|
||||||
|
no_pk_dicts = [_instance_to_dict(i) for i in no_pk]
|
||||||
|
for group_dicts, _ in _group_by_column_set(no_pk_dicts, no_pk):
|
||||||
|
await session.execute(pg_insert(model_cls).values(group_dicts))
|
||||||
|
|
||||||
|
if with_pk_pairs:
|
||||||
|
with_pk = [i for i, _ in with_pk_pairs]
|
||||||
|
with_pk_dicts = [_instance_to_dict(i) for i in with_pk]
|
||||||
|
for group_dicts, group_insts in _group_by_column_set(with_pk_dicts, with_pk):
|
||||||
|
stmt = (
|
||||||
|
pg_insert(model_cls)
|
||||||
|
.values(group_dicts)
|
||||||
|
.on_conflict_do_nothing(index_elements=pk_names)
|
||||||
|
)
|
||||||
|
result = await session.execute(stmt.returning(*mapper.primary_key))
|
||||||
|
inserted_pks = {
|
||||||
|
row[0] if len(pk_names) == 1 else tuple(row) for row in result
|
||||||
|
}
|
||||||
|
loaded.extend(
|
||||||
|
inst
|
||||||
|
for inst, pk in zip(
|
||||||
|
group_insts, [_get_primary_key(i) for i in group_insts]
|
||||||
|
)
|
||||||
|
if pk in inserted_pks
|
||||||
|
)
|
||||||
|
|
||||||
|
return loaded
|
||||||
|
|
||||||
|
|
||||||
|
async def _load_ordered(
|
||||||
|
session: AsyncSession,
|
||||||
|
registry: FixtureRegistry,
|
||||||
|
ordered_names: list[str],
|
||||||
|
strategy: LoadStrategy,
|
||||||
|
contexts: tuple[str, ...] | None = None,
|
||||||
|
) -> dict[str, list[DeclarativeBase]]:
|
||||||
|
"""Load fixtures in order using batch Core INSERT statements."""
|
||||||
|
results: dict[str, list[DeclarativeBase]] = {}
|
||||||
|
|
||||||
|
for name in ordered_names:
|
||||||
|
variants = (
|
||||||
|
registry.get_variants(name, *contexts)
|
||||||
|
if contexts is not None
|
||||||
|
else registry.get_variants(name)
|
||||||
|
)
|
||||||
|
|
||||||
|
if contexts is not None and not variants:
|
||||||
|
variants = registry.get_variants(name)
|
||||||
|
|
||||||
|
if not variants:
|
||||||
|
results[name] = []
|
||||||
|
continue
|
||||||
|
|
||||||
|
instances = [inst for v in variants for inst in v.func()]
|
||||||
|
|
||||||
|
if not instances:
|
||||||
|
results[name] = []
|
||||||
|
continue
|
||||||
|
|
||||||
|
model_name = type(instances[0]).__name__
|
||||||
|
loaded: list[DeclarativeBase] = []
|
||||||
|
|
||||||
|
async with get_transaction(session):
|
||||||
|
for model_cls, group in _group_by_type(instances):
|
||||||
|
match strategy:
|
||||||
|
case LoadStrategy.INSERT:
|
||||||
|
await _batch_insert(session, model_cls, group)
|
||||||
|
loaded.extend(group)
|
||||||
|
case LoadStrategy.MERGE:
|
||||||
|
await _batch_merge(session, model_cls, group)
|
||||||
|
loaded.extend(group)
|
||||||
|
case LoadStrategy.SKIP_EXISTING:
|
||||||
|
inserted = await _batch_skip_existing(session, model_cls, group)
|
||||||
|
loaded.extend(inserted)
|
||||||
|
|
||||||
|
results[name] = loaded
|
||||||
|
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
||||||
|
"""Get the primary key value of a model instance."""
|
||||||
|
mapper = instance.__class__.__mapper__
|
||||||
|
pk_cols = mapper.primary_key
|
||||||
|
|
||||||
|
if len(pk_cols) == 1:
|
||||||
|
return getattr(instance, pk_cols[0].name, None)
|
||||||
|
|
||||||
|
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
|
||||||
|
if all(v is not None for v in pk_values):
|
||||||
|
return pk_values
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_obj_by_attr(
|
def get_obj_by_attr(
|
||||||
fixtures: Callable[[], Sequence[T]], attr_name: str, value: Any
|
fixtures: Callable[[], Sequence[ModelType]], attr_name: str, value: Any
|
||||||
) -> T:
|
) -> ModelType:
|
||||||
"""Get a SQLAlchemy model instance by matching an attribute value.
|
"""Get a SQLAlchemy model instance by matching an attribute value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -21,6 +240,59 @@ def get_obj_by_attr(
|
|||||||
The first model instance where the attribute matches the given value.
|
The first model instance where the attribute matches the given value.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
StopIteration: If no matching object is found.
|
StopIteration: If no matching object is found in the fixture group.
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
return next(obj for obj in fixtures() if getattr(obj, attr_name) == value)
|
return next(obj for obj in fixtures() if getattr(obj, attr_name) == value)
|
||||||
|
except StopIteration:
|
||||||
|
raise StopIteration(
|
||||||
|
f"No object with {attr_name}={value} found in fixture '{getattr(fixtures, '__name__', repr(fixtures))}'"
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
async def load_fixtures(
|
||||||
|
session: AsyncSession,
|
||||||
|
registry: FixtureRegistry,
|
||||||
|
*names: str,
|
||||||
|
strategy: LoadStrategy = LoadStrategy.MERGE,
|
||||||
|
) -> dict[str, list[DeclarativeBase]]:
|
||||||
|
"""Load specific fixtures by name with dependencies.
|
||||||
|
|
||||||
|
All context variants of each requested fixture are loaded and merged.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
registry: Fixture registry
|
||||||
|
*names: Fixture names to load (dependencies auto-resolved)
|
||||||
|
strategy: How to handle existing records
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping fixture names to loaded instances
|
||||||
|
"""
|
||||||
|
ordered = registry.resolve_dependencies(*names)
|
||||||
|
return await _load_ordered(session, registry, ordered, strategy)
|
||||||
|
|
||||||
|
|
||||||
|
async def load_fixtures_by_context(
|
||||||
|
session: AsyncSession,
|
||||||
|
registry: FixtureRegistry,
|
||||||
|
*contexts: str | Enum,
|
||||||
|
strategy: LoadStrategy = LoadStrategy.MERGE,
|
||||||
|
) -> dict[str, list[DeclarativeBase]]:
|
||||||
|
"""Load all fixtures for specific contexts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
registry: Fixture registry
|
||||||
|
*contexts: Contexts to load (e.g., ``Context.BASE``, ``Context.TESTING``,
|
||||||
|
or plain strings for custom contexts)
|
||||||
|
strategy: How to handle existing records
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping fixture names to loaded instances
|
||||||
|
"""
|
||||||
|
context_strings = tuple(_normalize_contexts(contexts))
|
||||||
|
ordered = registry.resolve_context_dependencies(*contexts)
|
||||||
|
return await _load_ordered(
|
||||||
|
session, registry, ordered, strategy, contexts=context_strings
|
||||||
|
)
|
||||||
|
|||||||
98
src/fastapi_toolsets/logger.py
Normal file
98
src/fastapi_toolsets/logger.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""Logging configuration for FastAPI applications and CLI tools."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
__all__ = ["LogLevel", "configure_logging", "get_logger"]
|
||||||
|
|
||||||
|
DEFAULT_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
UVICORN_LOGGERS = ("uvicorn", "uvicorn.access", "uvicorn.error")
|
||||||
|
|
||||||
|
LogLevel = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
|
||||||
|
|
||||||
|
|
||||||
|
def configure_logging(
|
||||||
|
level: LogLevel | int = "INFO",
|
||||||
|
fmt: str = DEFAULT_FORMAT,
|
||||||
|
logger_name: str | None = None,
|
||||||
|
) -> logging.Logger:
|
||||||
|
"""Configure logging with a stdout handler and consistent format.
|
||||||
|
|
||||||
|
Sets up a :class:`~logging.StreamHandler` writing to stdout with the
|
||||||
|
given format and level. Also configures the uvicorn loggers so that
|
||||||
|
FastAPI access logs use the same format.
|
||||||
|
|
||||||
|
Calling this function multiple times is safe -- existing handlers are
|
||||||
|
replaced rather than duplicated.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: Log level (e.g. ``"DEBUG"``, ``"INFO"``, or ``logging.DEBUG``).
|
||||||
|
fmt: Log format string. Defaults to
|
||||||
|
``"%(asctime)s - %(name)s - %(levelname)s - %(message)s"``.
|
||||||
|
logger_name: Logger name to configure. ``None`` (the default)
|
||||||
|
configures the root logger so all loggers inherit the settings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The configured Logger instance.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import configure_logging
|
||||||
|
|
||||||
|
logger = configure_logging("DEBUG")
|
||||||
|
logger.info("Application started")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
formatter = logging.Formatter(fmt)
|
||||||
|
|
||||||
|
handler = logging.StreamHandler(sys.stdout)
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
logger = logging.getLogger(logger_name)
|
||||||
|
logger.handlers.clear()
|
||||||
|
logger.addHandler(handler)
|
||||||
|
logger.setLevel(level)
|
||||||
|
|
||||||
|
for name in UVICORN_LOGGERS:
|
||||||
|
uv_logger = logging.getLogger(name)
|
||||||
|
uv_logger.handlers.clear()
|
||||||
|
uv_logger.addHandler(handler)
|
||||||
|
uv_logger.setLevel(level)
|
||||||
|
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
_SENTINEL = object()
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(name: str | None = _SENTINEL) -> logging.Logger: # type: ignore[assignment] # ty:ignore[invalid-parameter-default]
|
||||||
|
"""Return a logger with the given *name*.
|
||||||
|
|
||||||
|
A thin convenience wrapper around :func:`logging.getLogger` that keeps
|
||||||
|
logging imports consistent across the codebase.
|
||||||
|
|
||||||
|
When called without arguments, the caller's ``__name__`` is used
|
||||||
|
automatically, so ``get_logger()`` in a module is equivalent to
|
||||||
|
``logging.getLogger(__name__)``. Pass ``None`` explicitly to get the
|
||||||
|
root logger.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Logger name. Defaults to the caller's ``__name__``.
|
||||||
|
Pass ``None`` to get the root logger.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Logger instance.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger() # uses caller's __name__
|
||||||
|
logger = get_logger("myapp") # explicit name
|
||||||
|
logger = get_logger(None) # root logger
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if name is _SENTINEL:
|
||||||
|
name = sys._getframe(1).f_globals.get("__name__")
|
||||||
|
return logging.getLogger(name)
|
||||||
21
src/fastapi_toolsets/metrics/__init__.py
Normal file
21
src/fastapi_toolsets/metrics/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""Prometheus metrics integration for FastAPI applications."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .registry import Metric, MetricsRegistry
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .handler import init_metrics
|
||||||
|
except ImportError:
|
||||||
|
|
||||||
|
def init_metrics(*_args: Any, **_kwargs: Any) -> None:
|
||||||
|
from .._imports import require_extra
|
||||||
|
|
||||||
|
require_extra(package="prometheus_client", extra="metrics")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Metric",
|
||||||
|
"MetricsRegistry",
|
||||||
|
"init_metrics",
|
||||||
|
]
|
||||||
81
src/fastapi_toolsets/metrics/handler.py
Normal file
81
src/fastapi_toolsets/metrics/handler.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
"""Prometheus metrics endpoint for FastAPI applications."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import os
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.responses import Response
|
||||||
|
from prometheus_client import (
|
||||||
|
CONTENT_TYPE_LATEST,
|
||||||
|
CollectorRegistry,
|
||||||
|
generate_latest,
|
||||||
|
multiprocess,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..logger import get_logger
|
||||||
|
from .registry import MetricsRegistry
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _is_multiprocess() -> bool:
|
||||||
|
"""Check if prometheus multi-process mode is enabled."""
|
||||||
|
return "PROMETHEUS_MULTIPROC_DIR" in os.environ
|
||||||
|
|
||||||
|
|
||||||
|
def init_metrics(
|
||||||
|
app: FastAPI,
|
||||||
|
registry: MetricsRegistry,
|
||||||
|
*,
|
||||||
|
path: str = "/metrics",
|
||||||
|
) -> FastAPI:
|
||||||
|
"""Register a Prometheus ``/metrics`` endpoint on a FastAPI app.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI application instance.
|
||||||
|
registry: A :class:`MetricsRegistry` containing providers and collectors.
|
||||||
|
path: URL path for the metrics endpoint (default ``/metrics``).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The same FastAPI instance (for chaining).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.metrics import MetricsRegistry, init_metrics
|
||||||
|
|
||||||
|
metrics = MetricsRegistry()
|
||||||
|
app = FastAPI()
|
||||||
|
init_metrics(app, registry=metrics)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
for provider in registry.get_providers():
|
||||||
|
logger.debug("Initialising metric provider '%s'", provider.name)
|
||||||
|
registry._instances[provider.name] = provider.func()
|
||||||
|
|
||||||
|
# Partition collectors and cache env check at startup — both are stable for the app lifetime.
|
||||||
|
async_collectors = [
|
||||||
|
c for c in registry.get_collectors() if inspect.iscoroutinefunction(c.func)
|
||||||
|
]
|
||||||
|
sync_collectors = [
|
||||||
|
c for c in registry.get_collectors() if not inspect.iscoroutinefunction(c.func)
|
||||||
|
]
|
||||||
|
multiprocess_mode = _is_multiprocess()
|
||||||
|
|
||||||
|
@app.get(path, include_in_schema=False)
|
||||||
|
async def metrics_endpoint() -> Response:
|
||||||
|
for collector in sync_collectors:
|
||||||
|
collector.func()
|
||||||
|
for collector in async_collectors:
|
||||||
|
await collector.func()
|
||||||
|
|
||||||
|
if multiprocess_mode:
|
||||||
|
prom_registry = CollectorRegistry()
|
||||||
|
multiprocess.MultiProcessCollector(prom_registry)
|
||||||
|
output = generate_latest(prom_registry)
|
||||||
|
else:
|
||||||
|
output = generate_latest()
|
||||||
|
|
||||||
|
return Response(content=output, media_type=CONTENT_TYPE_LATEST)
|
||||||
|
|
||||||
|
return app
|
||||||
104
src/fastapi_toolsets/metrics/registry.py
Normal file
104
src/fastapi_toolsets/metrics/registry.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
"""Metrics registry with decorator-based registration."""
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from ..logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Metric:
|
||||||
|
"""A metric definition with metadata."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
func: Callable[..., Any]
|
||||||
|
collect: bool = field(default=False)
|
||||||
|
|
||||||
|
|
||||||
|
class MetricsRegistry:
|
||||||
|
"""Registry for managing Prometheus metric providers and collectors."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._metrics: dict[str, Metric] = {}
|
||||||
|
self._instances: dict[str, Any] = {}
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self,
|
||||||
|
func: Callable[..., Any] | None = None,
|
||||||
|
*,
|
||||||
|
name: str | None = None,
|
||||||
|
collect: bool = False,
|
||||||
|
) -> Callable[..., Any]:
|
||||||
|
"""Register a metric provider or collector function.
|
||||||
|
|
||||||
|
Can be used as a decorator with or without arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The metric function to register.
|
||||||
|
name: Metric name (defaults to function name).
|
||||||
|
collect: If ``True``, the function is called on every scrape.
|
||||||
|
If ``False`` (default), called once at init time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
|
metric_name = name or cast(Any, fn).__name__
|
||||||
|
self._metrics[metric_name] = Metric(
|
||||||
|
name=metric_name,
|
||||||
|
func=fn,
|
||||||
|
collect=collect,
|
||||||
|
)
|
||||||
|
return fn
|
||||||
|
|
||||||
|
if func is not None:
|
||||||
|
return decorator(func)
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def get(self, name: str) -> Any:
|
||||||
|
"""Return the metric instance created by a provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The metric name (defaults to the provider function name).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the metric name is unknown or ``init_metrics`` has not
|
||||||
|
been called yet.
|
||||||
|
"""
|
||||||
|
if name not in self._instances:
|
||||||
|
if name in self._metrics:
|
||||||
|
raise KeyError(
|
||||||
|
f"Metric '{name}' exists but has not been initialized yet. "
|
||||||
|
"Ensure init_metrics() has been called before accessing metric instances."
|
||||||
|
)
|
||||||
|
raise KeyError(f"Unknown metric '{name}'.")
|
||||||
|
return self._instances[name]
|
||||||
|
|
||||||
|
def include_registry(self, registry: "MetricsRegistry") -> None:
|
||||||
|
"""Include another :class:`MetricsRegistry` into this one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registry: The registry to merge in.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If a metric name already exists in the current registry.
|
||||||
|
"""
|
||||||
|
for metric_name, definition in registry._metrics.items():
|
||||||
|
if metric_name in self._metrics:
|
||||||
|
raise ValueError(
|
||||||
|
f"Metric '{metric_name}' already exists in the current registry"
|
||||||
|
)
|
||||||
|
self._metrics[metric_name] = definition
|
||||||
|
|
||||||
|
def get_all(self) -> list[Metric]:
|
||||||
|
"""Get all registered metric definitions."""
|
||||||
|
return list(self._metrics.values())
|
||||||
|
|
||||||
|
def get_providers(self) -> list[Metric]:
|
||||||
|
"""Get metric providers (called once at init)."""
|
||||||
|
return [m for m in self._metrics.values() if not m.collect]
|
||||||
|
|
||||||
|
def get_collectors(self) -> list[Metric]:
|
||||||
|
"""Get collectors (called on each scrape)."""
|
||||||
|
return [m for m in self._metrics.values() if m.collect]
|
||||||
21
src/fastapi_toolsets/models/__init__.py
Normal file
21
src/fastapi_toolsets/models/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""SQLAlchemy model mixins for common column patterns."""
|
||||||
|
|
||||||
|
from .columns import (
|
||||||
|
CreatedAtMixin,
|
||||||
|
TimestampMixin,
|
||||||
|
UUIDMixin,
|
||||||
|
UUIDv7Mixin,
|
||||||
|
UpdatedAtMixin,
|
||||||
|
)
|
||||||
|
from .watched import EventSession, ModelEvent, listens_for
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"EventSession",
|
||||||
|
"ModelEvent",
|
||||||
|
"UUIDMixin",
|
||||||
|
"UUIDv7Mixin",
|
||||||
|
"CreatedAtMixin",
|
||||||
|
"UpdatedAtMixin",
|
||||||
|
"TimestampMixin",
|
||||||
|
"listens_for",
|
||||||
|
]
|
||||||
50
src/fastapi_toolsets/models/columns.py
Normal file
50
src/fastapi_toolsets/models/columns.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""SQLAlchemy column mixins for common column patterns."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, Uuid, text
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
|
||||||
|
class UUIDMixin:
|
||||||
|
"""Mixin that adds a UUID primary key auto-generated by the database."""
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
Uuid,
|
||||||
|
primary_key=True,
|
||||||
|
server_default=text("gen_random_uuid()"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UUIDv7Mixin:
|
||||||
|
"""Mixin that adds a UUIDv7 primary key auto-generated by the database."""
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
Uuid,
|
||||||
|
primary_key=True,
|
||||||
|
server_default=text("uuidv7()"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CreatedAtMixin:
|
||||||
|
"""Mixin that adds a ``created_at`` timestamp column."""
|
||||||
|
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
server_default=text("clock_timestamp()"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UpdatedAtMixin:
|
||||||
|
"""Mixin that adds an ``updated_at`` timestamp column."""
|
||||||
|
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
server_default=text("clock_timestamp()"),
|
||||||
|
onupdate=text("clock_timestamp()"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TimestampMixin(CreatedAtMixin, UpdatedAtMixin):
|
||||||
|
"""Mixin that combines ``created_at`` and ``updated_at`` timestamp columns."""
|
||||||
290
src/fastapi_toolsets/models/watched.py
Normal file
290
src/fastapi_toolsets/models/watched.py
Normal file
@@ -0,0 +1,290 @@
|
|||||||
|
"""Field-change monitoring via SQLAlchemy session events."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from collections.abc import Callable
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import event
|
||||||
|
from sqlalchemy import inspect as sa_inspect
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm.attributes import set_committed_value as _sa_set_committed_value
|
||||||
|
|
||||||
|
from ..logger import get_logger
|
||||||
|
|
||||||
|
_logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class ModelEvent(str, Enum):
|
||||||
|
"""Event types dispatched by :class:`EventSession`."""
|
||||||
|
|
||||||
|
CREATE = "create"
|
||||||
|
DELETE = "delete"
|
||||||
|
UPDATE = "update"
|
||||||
|
|
||||||
|
|
||||||
|
_CALLBACK_ERROR_MSG = "Event callback raised an unhandled exception"
|
||||||
|
_SESSION_CREATES = "_ft_creates"
|
||||||
|
_SESSION_DELETES = "_ft_deletes"
|
||||||
|
_SESSION_UPDATES = "_ft_updates"
|
||||||
|
_DEFERRED_STRATEGY_KEY = (("deferred", True), ("instrument", True))
|
||||||
|
_EVENT_HANDLERS: dict[tuple[type, ModelEvent], list[Callable[..., Any]]] = {}
|
||||||
|
_WATCHED_MODELS: set[type] = set()
|
||||||
|
_WATCHED_CACHE: dict[type, bool] = {}
|
||||||
|
_HANDLER_CACHE: dict[tuple[type, ModelEvent], list[Callable[..., Any]]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _invalidate_caches() -> None:
|
||||||
|
"""Clear lookup caches after handler registration."""
|
||||||
|
_WATCHED_CACHE.clear()
|
||||||
|
_HANDLER_CACHE.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def listens_for(
|
||||||
|
model_class: type,
|
||||||
|
event_types: list[ModelEvent] | None = None,
|
||||||
|
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||||
|
"""Register a callback for one or more model lifecycle events.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_class: The SQLAlchemy model class to listen on.
|
||||||
|
event_types: List of :class:`ModelEvent` values to listen for.
|
||||||
|
Defaults to all event types.
|
||||||
|
"""
|
||||||
|
evs = event_types if event_types is not None else list(ModelEvent)
|
||||||
|
|
||||||
|
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
|
for ev in evs:
|
||||||
|
_EVENT_HANDLERS.setdefault((model_class, ev), []).append(fn)
|
||||||
|
_WATCHED_MODELS.add(model_class)
|
||||||
|
_invalidate_caches()
|
||||||
|
return fn
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def _is_watched(obj: Any) -> bool:
|
||||||
|
"""Return True if *obj*'s type (or any ancestor) has registered handlers."""
|
||||||
|
cls = type(obj)
|
||||||
|
try:
|
||||||
|
return _WATCHED_CACHE[cls]
|
||||||
|
except KeyError:
|
||||||
|
result = any(klass in _WATCHED_MODELS for klass in cls.__mro__)
|
||||||
|
_WATCHED_CACHE[cls] = result
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _get_handlers(cls: type, ev: ModelEvent) -> list[Callable[..., Any]]:
|
||||||
|
"""Return registered handlers for *cls* and *ev*, walking the MRO."""
|
||||||
|
key = (cls, ev)
|
||||||
|
try:
|
||||||
|
return _HANDLER_CACHE[key]
|
||||||
|
except KeyError:
|
||||||
|
handlers: list[Callable[..., Any]] = []
|
||||||
|
for klass in cls.__mro__:
|
||||||
|
handlers.extend(_EVENT_HANDLERS.get((klass, ev), []))
|
||||||
|
_HANDLER_CACHE[key] = handlers
|
||||||
|
return handlers
|
||||||
|
|
||||||
|
|
||||||
|
def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
|
||||||
|
"""Read currently-loaded column values into a plain dict."""
|
||||||
|
state = sa_inspect(obj) # InstanceState
|
||||||
|
state_dict = state.dict
|
||||||
|
snapshot: dict[str, Any] = {}
|
||||||
|
for prop in state.mapper.column_attrs:
|
||||||
|
if prop.key in state_dict:
|
||||||
|
snapshot[prop.key] = state_dict[prop.key]
|
||||||
|
elif ( # pragma: no cover
|
||||||
|
not state.expired
|
||||||
|
and prop.strategy_key != _DEFERRED_STRATEGY_KEY
|
||||||
|
and all(
|
||||||
|
col.nullable
|
||||||
|
and col.server_default is None
|
||||||
|
and col.server_onupdate is None
|
||||||
|
for col in prop.columns
|
||||||
|
)
|
||||||
|
):
|
||||||
|
snapshot[prop.key] = None
|
||||||
|
return snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def _get_watched_fields(cls: type) -> tuple[str, ...] | None:
|
||||||
|
"""Return the watched fields for *cls*."""
|
||||||
|
fields = getattr(cls, "__watched_fields__", None)
|
||||||
|
if fields is not None and (
|
||||||
|
not isinstance(fields, tuple) or not all(isinstance(f, str) for f in fields)
|
||||||
|
):
|
||||||
|
raise TypeError(
|
||||||
|
f"{cls.__name__}.__watched_fields__ must be a tuple[str, ...], "
|
||||||
|
f"got {type(fields).__name__}"
|
||||||
|
)
|
||||||
|
return fields
|
||||||
|
|
||||||
|
|
||||||
|
def _upsert_changes(
|
||||||
|
pending: dict[int, tuple[Any, dict[str, dict[str, Any]]]],
|
||||||
|
obj: Any,
|
||||||
|
changes: dict[str, dict[str, Any]],
|
||||||
|
) -> None:
|
||||||
|
"""Insert or merge *changes* into *pending* for *obj*."""
|
||||||
|
key = id(obj)
|
||||||
|
if key in pending:
|
||||||
|
existing = pending[key][1]
|
||||||
|
for field, change in changes.items():
|
||||||
|
if field in existing:
|
||||||
|
existing[field]["new"] = change["new"]
|
||||||
|
else:
|
||||||
|
existing[field] = change
|
||||||
|
else:
|
||||||
|
pending[key] = (obj, changes)
|
||||||
|
|
||||||
|
|
||||||
|
@event.listens_for(AsyncSession.sync_session_class, "after_flush")
|
||||||
|
def _after_flush(session: Any, flush_context: Any) -> None:
|
||||||
|
# New objects: capture reference. Attributes will be refreshed after commit.
|
||||||
|
for obj in session.new:
|
||||||
|
if _is_watched(obj):
|
||||||
|
session.info.setdefault(_SESSION_CREATES, []).append(obj)
|
||||||
|
|
||||||
|
# Deleted objects: snapshot now while attributes are still loaded.
|
||||||
|
for obj in session.deleted:
|
||||||
|
if _is_watched(obj):
|
||||||
|
snapshot = _snapshot_column_attrs(obj)
|
||||||
|
session.info.setdefault(_SESSION_DELETES, []).append((obj, snapshot))
|
||||||
|
|
||||||
|
# Dirty objects: read old/new from SQLAlchemy attribute history.
|
||||||
|
for obj in session.dirty:
|
||||||
|
if not _is_watched(obj):
|
||||||
|
continue
|
||||||
|
|
||||||
|
watched = _get_watched_fields(type(obj))
|
||||||
|
changes: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
|
inst_attrs = sa_inspect(obj).attrs
|
||||||
|
attrs = (
|
||||||
|
((field, inst_attrs[field]) for field in watched)
|
||||||
|
if watched is not None
|
||||||
|
else ((s.key, s) for s in inst_attrs)
|
||||||
|
)
|
||||||
|
for field, attr_state in attrs:
|
||||||
|
history = attr_state.history
|
||||||
|
if history.has_changes() and history.deleted:
|
||||||
|
changes[field] = {
|
||||||
|
"old": history.deleted[0],
|
||||||
|
"new": history.added[0] if history.added else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
if changes:
|
||||||
|
_upsert_changes(
|
||||||
|
session.info.setdefault(_SESSION_UPDATES, {}),
|
||||||
|
obj,
|
||||||
|
changes,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@event.listens_for(AsyncSession.sync_session_class, "after_rollback")
|
||||||
|
def _after_rollback(session: Any) -> None:
|
||||||
|
if session.in_transaction():
|
||||||
|
return
|
||||||
|
session.info.pop(_SESSION_CREATES, None)
|
||||||
|
session.info.pop(_SESSION_DELETES, None)
|
||||||
|
session.info.pop(_SESSION_UPDATES, None)
|
||||||
|
|
||||||
|
|
||||||
|
async def _invoke_callback(
|
||||||
|
fn: Callable[..., Any],
|
||||||
|
obj: Any,
|
||||||
|
event_type: ModelEvent,
|
||||||
|
changes: dict[str, dict[str, Any]] | None,
|
||||||
|
) -> None:
|
||||||
|
"""Call *fn* and await the result if it is awaitable."""
|
||||||
|
result = fn(obj, event_type, changes)
|
||||||
|
if inspect.isawaitable(result):
|
||||||
|
await result
|
||||||
|
|
||||||
|
|
||||||
|
class EventSession(AsyncSession):
|
||||||
|
"""AsyncSession subclass that dispatches lifecycle callbacks after commit."""
|
||||||
|
|
||||||
|
async def commit(self) -> None: # noqa: C901
|
||||||
|
await super().commit()
|
||||||
|
|
||||||
|
creates: list[Any] = self.info.pop(_SESSION_CREATES, [])
|
||||||
|
deletes: list[tuple[Any, dict[str, Any]]] = self.info.pop(_SESSION_DELETES, [])
|
||||||
|
field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = self.info.pop(
|
||||||
|
_SESSION_UPDATES, {}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not creates and not deletes and not field_changes:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Suppress transient objects (created + deleted in same transaction).
|
||||||
|
if creates and deletes:
|
||||||
|
created_ids = {id(o) for o in creates}
|
||||||
|
deleted_ids = {id(o) for o, _ in deletes}
|
||||||
|
transient_ids = created_ids & deleted_ids
|
||||||
|
if transient_ids:
|
||||||
|
creates = [o for o in creates if id(o) not in transient_ids]
|
||||||
|
deletes = [(o, s) for o, s in deletes if id(o) not in transient_ids]
|
||||||
|
field_changes = {
|
||||||
|
k: v for k, v in field_changes.items() if k not in transient_ids
|
||||||
|
}
|
||||||
|
|
||||||
|
# Suppress updates for deleted objects (row is gone, refresh would fail).
|
||||||
|
if deletes and field_changes:
|
||||||
|
deleted_ids = {id(o) for o, _ in deletes}
|
||||||
|
field_changes = {
|
||||||
|
k: v for k, v in field_changes.items() if k not in deleted_ids
|
||||||
|
}
|
||||||
|
|
||||||
|
# Suppress updates for newly created objects (CREATE-only semantics).
|
||||||
|
if creates and field_changes:
|
||||||
|
create_ids = {id(o) for o in creates}
|
||||||
|
field_changes = {
|
||||||
|
k: v for k, v in field_changes.items() if k not in create_ids
|
||||||
|
}
|
||||||
|
|
||||||
|
# Dispatch CREATE callbacks.
|
||||||
|
for obj in creates:
|
||||||
|
try:
|
||||||
|
state = sa_inspect(obj, raiseerr=False)
|
||||||
|
if (
|
||||||
|
state is None or state.detached or state.transient
|
||||||
|
): # pragma: no cover
|
||||||
|
continue
|
||||||
|
await self.refresh(obj)
|
||||||
|
for handler in _get_handlers(type(obj), ModelEvent.CREATE):
|
||||||
|
await _invoke_callback(handler, obj, ModelEvent.CREATE, None)
|
||||||
|
except Exception as exc:
|
||||||
|
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||||
|
|
||||||
|
# Dispatch DELETE callbacks (restore snapshot; row is gone).
|
||||||
|
for obj, snapshot in deletes:
|
||||||
|
try:
|
||||||
|
for key, value in snapshot.items():
|
||||||
|
_sa_set_committed_value(obj, key, value)
|
||||||
|
for handler in _get_handlers(type(obj), ModelEvent.DELETE):
|
||||||
|
await _invoke_callback(handler, obj, ModelEvent.DELETE, None)
|
||||||
|
except Exception as exc:
|
||||||
|
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||||
|
|
||||||
|
# Dispatch UPDATE callbacks.
|
||||||
|
for obj, changes in field_changes.values():
|
||||||
|
try:
|
||||||
|
state = sa_inspect(obj, raiseerr=False)
|
||||||
|
if (
|
||||||
|
state is None or state.detached or state.transient
|
||||||
|
): # pragma: no cover
|
||||||
|
continue
|
||||||
|
await self.refresh(obj)
|
||||||
|
for handler in _get_handlers(type(obj), ModelEvent.UPDATE):
|
||||||
|
await _invoke_callback(handler, obj, ModelEvent.UPDATE, changes)
|
||||||
|
except Exception as exc:
|
||||||
|
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||||
|
|
||||||
|
async def rollback(self) -> None:
|
||||||
|
await super().rollback()
|
||||||
|
self.info.pop(_SESSION_CREATES, None)
|
||||||
|
self.info.pop(_SESSION_DELETES, None)
|
||||||
|
self.info.pop(_SESSION_UPDATES, None)
|
||||||
30
src/fastapi_toolsets/pytest/__init__.py
Normal file
30
src/fastapi_toolsets/pytest/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""Pytest helpers for FastAPI testing: sessions, clients, and fixtures."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .plugin import register_fixtures
|
||||||
|
except ImportError:
|
||||||
|
from .._imports import require_extra
|
||||||
|
|
||||||
|
require_extra(package="pytest", extra="pytest")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .utils import (
|
||||||
|
cleanup_tables,
|
||||||
|
create_async_client,
|
||||||
|
create_db_session,
|
||||||
|
create_worker_database,
|
||||||
|
worker_database_url,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
from .._imports import require_extra
|
||||||
|
|
||||||
|
require_extra(package="httpx", extra="pytest")
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"cleanup_tables",
|
||||||
|
"create_async_client",
|
||||||
|
"create_db_session",
|
||||||
|
"create_worker_database",
|
||||||
|
"register_fixtures",
|
||||||
|
"worker_database_url",
|
||||||
|
]
|
||||||
@@ -1,55 +1,4 @@
|
|||||||
"""Pytest plugin for using FixtureRegistry fixtures in tests.
|
"""Pytest plugin for using FixtureRegistry fixtures in tests."""
|
||||||
|
|
||||||
This module provides utilities to automatically generate pytest fixtures
|
|
||||||
from your FixtureRegistry, with proper dependency resolution.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# conftest.py
|
|
||||||
import pytest
|
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
|
||||||
|
|
||||||
from app.fixtures import fixtures # Your FixtureRegistry
|
|
||||||
from app.models import Base
|
|
||||||
from fastapi_toolsets.pytest_plugin import register_fixtures
|
|
||||||
|
|
||||||
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/test_db"
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def engine():
|
|
||||||
engine = create_async_engine(DATABASE_URL)
|
|
||||||
yield engine
|
|
||||||
await engine.dispose()
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def db_session(engine):
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
|
||||||
|
|
||||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
|
||||||
session = session_factory()
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield session
|
|
||||||
finally:
|
|
||||||
await session.close()
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.drop_all)
|
|
||||||
|
|
||||||
# Automatically generate pytest fixtures from registry
|
|
||||||
# Creates: fixture_roles, fixture_users, fixture_posts, etc.
|
|
||||||
register_fixtures(fixtures, globals())
|
|
||||||
|
|
||||||
Usage in tests:
|
|
||||||
# test_users.py
|
|
||||||
async def test_user_count(db_session, fixture_users):
|
|
||||||
# fixture_users automatically loads fixture_roles first (if dependency)
|
|
||||||
# and returns the list of User models
|
|
||||||
assert len(fixture_users) > 0
|
|
||||||
|
|
||||||
async def test_user_role(db_session, fixture_users):
|
|
||||||
user = fixture_users[0]
|
|
||||||
assert user.role_id is not None
|
|
||||||
"""
|
|
||||||
|
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -59,7 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
from ..db import get_transaction
|
from ..db import get_transaction
|
||||||
from .fixtures import FixtureRegistry, LoadStrategy
|
from ..fixtures import FixtureRegistry, LoadStrategy
|
||||||
|
|
||||||
|
|
||||||
def register_fixtures(
|
def register_fixtures(
|
||||||
@@ -86,6 +35,7 @@ def register_fixtures(
|
|||||||
List of created fixture names
|
List of created fixture names
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
# conftest.py
|
# conftest.py
|
||||||
from app.fixtures import fixtures
|
from app.fixtures import fixtures
|
||||||
from fastapi_toolsets.pytest_plugin import register_fixtures
|
from fastapi_toolsets.pytest_plugin import register_fixtures
|
||||||
@@ -96,6 +46,7 @@ def register_fixtures(
|
|||||||
# - fixture_roles
|
# - fixture_roles
|
||||||
# - fixture_users (depends on fixture_roles if users depends on roles)
|
# - fixture_users (depends on fixture_roles if users depends on roles)
|
||||||
# - fixture_posts (depends on fixture_users if posts depends on users)
|
# - fixture_posts (depends on fixture_users if posts depends on users)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
created_fixtures: list[str] = []
|
created_fixtures: list[str] = []
|
||||||
|
|
||||||
260
src/fastapi_toolsets/pytest/utils.py
Normal file
260
src/fastapi_toolsets/pytest/utils.py
Normal file
@@ -0,0 +1,260 @@
|
|||||||
|
"""Pytest helper utilities for FastAPI testing."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from collections.abc import AsyncGenerator, Callable
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.engine import make_url
|
||||||
|
from sqlalchemy.ext.asyncio import (
|
||||||
|
AsyncSession,
|
||||||
|
async_sessionmaker,
|
||||||
|
create_async_engine,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from ..db import cleanup_tables, create_database
|
||||||
|
from ..models.watched import EventSession
|
||||||
|
|
||||||
|
|
||||||
|
def _get_xdist_worker(default_test_db: str) -> str:
|
||||||
|
"""Return the pytest-xdist worker name, or *default_test_db* when not running under xdist.
|
||||||
|
|
||||||
|
Reads the ``PYTEST_XDIST_WORKER`` environment variable that xdist sets
|
||||||
|
automatically in each worker process (e.g. ``"gw0"``, ``"gw1"``).
|
||||||
|
When xdist is not installed or not active, the variable is absent and
|
||||||
|
*default_test_db* is returned instead.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
default_test_db: Fallback value returned when ``PYTEST_XDIST_WORKER``
|
||||||
|
is not set.
|
||||||
|
"""
|
||||||
|
return os.environ.get("PYTEST_XDIST_WORKER", default_test_db)
|
||||||
|
|
||||||
|
|
||||||
|
def worker_database_url(database_url: str, default_test_db: str) -> str:
|
||||||
|
"""Derive a per-worker database URL for pytest-xdist parallel runs.
|
||||||
|
|
||||||
|
Appends ``_{worker_name}`` to the database name so each xdist worker
|
||||||
|
operates on its own database. When not running under xdist,
|
||||||
|
``_{default_test_db}`` is appended instead.
|
||||||
|
|
||||||
|
The worker name is read from the ``PYTEST_XDIST_WORKER`` environment
|
||||||
|
variable (set automatically by xdist in each worker process).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database_url: Original database connection URL.
|
||||||
|
default_test_db: Suffix appended to the database name when
|
||||||
|
``PYTEST_XDIST_WORKER`` is not set.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A database URL with a worker- or default-specific database name.
|
||||||
|
"""
|
||||||
|
worker = _get_xdist_worker(default_test_db=default_test_db)
|
||||||
|
|
||||||
|
url = make_url(database_url)
|
||||||
|
url = url.set(database=f"{url.database}_{worker}")
|
||||||
|
return url.render_as_string(hide_password=False)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def create_worker_database(
|
||||||
|
database_url: str,
|
||||||
|
default_test_db: str = "test_db",
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""Create and drop a per-worker database for pytest-xdist isolation.
|
||||||
|
|
||||||
|
Derives a worker-specific database URL using :func:`worker_database_url`,
|
||||||
|
then delegates to :func:`~fastapi_toolsets.db.create_database` to create
|
||||||
|
and drop it. Intended for use as a **session-scoped** fixture.
|
||||||
|
|
||||||
|
When running under xdist the database name is suffixed with the worker
|
||||||
|
name (e.g. ``_gw0``). Otherwise it is suffixed with *default_test_db*.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database_url: Original database connection URL (used as the server
|
||||||
|
connection and as the base for the worker database name).
|
||||||
|
default_test_db: Suffix appended to the database name when
|
||||||
|
``PYTEST_XDIST_WORKER`` is not set. Defaults to ``"test_db"``.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The worker-specific database URL.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import create_worker_database, create_db_session
|
||||||
|
|
||||||
|
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost/test_db"
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
async def worker_db_url():
|
||||||
|
async with create_worker_database(DATABASE_URL) as url:
|
||||||
|
yield url
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session(worker_db_url):
|
||||||
|
async with create_db_session(
|
||||||
|
worker_db_url, Base, cleanup=True
|
||||||
|
) as session:
|
||||||
|
yield session
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
worker_url = worker_database_url(
|
||||||
|
database_url=database_url, default_test_db=default_test_db
|
||||||
|
)
|
||||||
|
worker_db_name = make_url(worker_url).database
|
||||||
|
assert worker_db_name is not None
|
||||||
|
|
||||||
|
engine = create_async_engine(database_url, isolation_level="AUTOCOMMIT")
|
||||||
|
try:
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
await conn.execute(text(f"DROP DATABASE IF EXISTS {worker_db_name}"))
|
||||||
|
await create_database(db_name=worker_db_name, server_url=database_url)
|
||||||
|
|
||||||
|
yield worker_url
|
||||||
|
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
await conn.execute(text(f"DROP DATABASE IF EXISTS {worker_db_name}"))
|
||||||
|
finally:
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def create_async_client(
|
||||||
|
app: Any,
|
||||||
|
base_url: str = "http://test",
|
||||||
|
dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None,
|
||||||
|
) -> AsyncGenerator[AsyncClient, None]:
|
||||||
|
"""Create an async httpx client for testing FastAPI applications.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI application instance.
|
||||||
|
base_url: Base URL for requests. Defaults to "http://test".
|
||||||
|
dependency_overrides: Optional mapping of original dependencies to
|
||||||
|
their test replacements. Applied via ``app.dependency_overrides``
|
||||||
|
before yielding and cleaned up after.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
An AsyncClient configured for the app.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.pytest import create_async_client
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client():
|
||||||
|
async with create_async_client(app) as c:
|
||||||
|
yield c
|
||||||
|
|
||||||
|
async def test_endpoint(client: AsyncClient):
|
||||||
|
response = await client.get("/health")
|
||||||
|
assert response.status_code == 200
|
||||||
|
```
|
||||||
|
|
||||||
|
Example with dependency overrides:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import create_async_client, create_db_session
|
||||||
|
from app.db import get_db
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session():
|
||||||
|
async with create_db_session(DATABASE_URL, Base, cleanup=True) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client(db_session):
|
||||||
|
async def override():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
async with create_async_client(
|
||||||
|
app, dependency_overrides={get_db: override}
|
||||||
|
) as c:
|
||||||
|
yield c
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if dependency_overrides:
|
||||||
|
app.dependency_overrides.update(dependency_overrides)
|
||||||
|
|
||||||
|
transport = ASGITransport(app=app)
|
||||||
|
try:
|
||||||
|
async with AsyncClient(transport=transport, base_url=base_url) as client:
|
||||||
|
yield client
|
||||||
|
finally:
|
||||||
|
if dependency_overrides:
|
||||||
|
for key in dependency_overrides:
|
||||||
|
app.dependency_overrides.pop(key, None)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def create_db_session(
|
||||||
|
database_url: str,
|
||||||
|
base: type[DeclarativeBase],
|
||||||
|
*,
|
||||||
|
echo: bool = False,
|
||||||
|
expire_on_commit: bool = False,
|
||||||
|
drop_tables: bool = True,
|
||||||
|
cleanup: bool = False,
|
||||||
|
) -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""Create a database session for testing.
|
||||||
|
|
||||||
|
Creates tables before yielding the session and optionally drops them after.
|
||||||
|
Each call creates a fresh engine and session for test isolation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database_url: Database connection URL (e.g., "postgresql+asyncpg://...").
|
||||||
|
base: SQLAlchemy DeclarativeBase class containing model metadata.
|
||||||
|
echo: Enable SQLAlchemy query logging. Defaults to False.
|
||||||
|
expire_on_commit: Expire objects after commit. Defaults to False.
|
||||||
|
drop_tables: Drop tables after test. Defaults to True.
|
||||||
|
cleanup: Truncate all tables after test using
|
||||||
|
:func:`cleanup_tables`. Defaults to False.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
An AsyncSession ready for database operations.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import create_db_session
|
||||||
|
from app.models import Base
|
||||||
|
|
||||||
|
DATABASE_URL = "postgresql+asyncpg://user:pass@localhost/test_db"
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session():
|
||||||
|
async with create_db_session(
|
||||||
|
DATABASE_URL, Base, cleanup=True
|
||||||
|
) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
async def test_create_user(db_session: AsyncSession):
|
||||||
|
user = User(name="test")
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
engine = create_async_engine(database_url, echo=echo)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create tables
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(base.metadata.create_all)
|
||||||
|
|
||||||
|
session_maker = async_sessionmaker(
|
||||||
|
engine, expire_on_commit=expire_on_commit, class_=EventSession
|
||||||
|
)
|
||||||
|
async with session_maker() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
if cleanup:
|
||||||
|
await cleanup_tables(session=session, base=base)
|
||||||
|
|
||||||
|
if drop_tables:
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(base.metadata.drop_all)
|
||||||
|
finally:
|
||||||
|
await engine.dispose()
|
||||||
@@ -1,21 +1,27 @@
|
|||||||
"""Base Pydantic schemas for API responses."""
|
"""Base Pydantic schemas for API responses."""
|
||||||
|
|
||||||
|
import math
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import ClassVar, Generic, TypeVar
|
from typing import Annotated, Any, ClassVar, Generic, Literal, TypeVar, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict, Field, computed_field
|
||||||
|
|
||||||
|
from .types import DataT
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ApiError",
|
"ApiError",
|
||||||
|
"CursorPagination",
|
||||||
|
"CursorPaginatedResponse",
|
||||||
"ErrorResponse",
|
"ErrorResponse",
|
||||||
"Pagination",
|
"OffsetPagination",
|
||||||
|
"OffsetPaginatedResponse",
|
||||||
"PaginatedResponse",
|
"PaginatedResponse",
|
||||||
|
"PaginationType",
|
||||||
|
"PydanticBase",
|
||||||
"Response",
|
"Response",
|
||||||
"ResponseStatus",
|
"ResponseStatus",
|
||||||
]
|
]
|
||||||
|
|
||||||
DataT = TypeVar("DataT")
|
|
||||||
|
|
||||||
|
|
||||||
class PydanticBase(BaseModel):
|
class PydanticBase(BaseModel):
|
||||||
"""Base class for all Pydantic models with common configuration."""
|
"""Base class for all Pydantic models with common configuration."""
|
||||||
@@ -49,6 +55,7 @@ class ApiError(PydanticBase):
|
|||||||
msg: str
|
msg: str
|
||||||
desc: str
|
desc: str
|
||||||
err_code: str
|
err_code: str
|
||||||
|
data: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
class BaseResponse(PydanticBase):
|
class BaseResponse(PydanticBase):
|
||||||
@@ -69,7 +76,9 @@ class Response(BaseResponse, Generic[DataT]):
|
|||||||
"""Generic API response with data payload.
|
"""Generic API response with data payload.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
Response[UserRead](data=user, message="User retrieved")
|
Response[UserRead](data=user, message="User retrieved")
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data: DataT | None = None
|
data: DataT | None = None
|
||||||
@@ -83,34 +92,114 @@ class ErrorResponse(BaseResponse):
|
|||||||
|
|
||||||
status: ResponseStatus = ResponseStatus.FAIL
|
status: ResponseStatus = ResponseStatus.FAIL
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
data: None = None
|
data: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
class Pagination(PydanticBase):
|
class OffsetPagination(PydanticBase):
|
||||||
"""Pagination metadata for list responses.
|
"""Pagination metadata for offset-based list responses.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
total_count: Total number of items across all pages
|
total_count: Total number of items across all pages.
|
||||||
|
``None`` when ``include_total=False``.
|
||||||
items_per_page: Number of items per page
|
items_per_page: Number of items per page
|
||||||
page: Current page number (1-indexed)
|
page: Current page number (1-indexed)
|
||||||
has_more: Whether there are more pages
|
has_more: Whether there are more pages
|
||||||
|
pages: Total number of pages
|
||||||
"""
|
"""
|
||||||
|
|
||||||
total_count: int
|
total_count: int | None
|
||||||
items_per_page: int
|
items_per_page: int
|
||||||
page: int
|
page: int
|
||||||
has_more: bool
|
has_more: bool
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def pages(self) -> int | None:
|
||||||
|
"""Total number of pages, or ``None`` when ``total_count`` is unknown."""
|
||||||
|
if self.total_count is None:
|
||||||
|
return None
|
||||||
|
if self.items_per_page == 0:
|
||||||
|
return 0
|
||||||
|
return math.ceil(self.total_count / self.items_per_page)
|
||||||
|
|
||||||
|
|
||||||
|
class CursorPagination(PydanticBase):
|
||||||
|
"""Pagination metadata for cursor-based list responses.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
next_cursor: Encoded cursor for the next page, or None on the last page.
|
||||||
|
prev_cursor: Encoded cursor for the previous page, or None on the first page.
|
||||||
|
items_per_page: Number of items requested per page.
|
||||||
|
has_more: Whether there is at least one more page after this one.
|
||||||
|
"""
|
||||||
|
|
||||||
|
next_cursor: str | None
|
||||||
|
prev_cursor: str | None = None
|
||||||
|
items_per_page: int
|
||||||
|
has_more: bool
|
||||||
|
|
||||||
|
|
||||||
|
class PaginationType(str, Enum):
|
||||||
|
"""Pagination strategy selector for :meth:`.AsyncCrud.paginate`."""
|
||||||
|
|
||||||
|
OFFSET = "offset"
|
||||||
|
CURSOR = "cursor"
|
||||||
|
|
||||||
|
|
||||||
class PaginatedResponse(BaseResponse, Generic[DataT]):
|
class PaginatedResponse(BaseResponse, Generic[DataT]):
|
||||||
"""Paginated API response for list endpoints.
|
"""Paginated API response for list endpoints.
|
||||||
|
|
||||||
Example:
|
Base class and return type for endpoints that support both pagination
|
||||||
PaginatedResponse[UserRead](
|
strategies. Use :class:`OffsetPaginatedResponse` or
|
||||||
data=users,
|
:class:`CursorPaginatedResponse` when the strategy is fixed.
|
||||||
pagination=Pagination(total_count=100, items_per_page=10, page=1, has_more=True)
|
|
||||||
)
|
When used as ``PaginatedResponse[T]`` in a return annotation, subscripting
|
||||||
|
returns ``Annotated[Union[CursorPaginatedResponse[T], OffsetPaginatedResponse[T]], Field(discriminator="pagination_type")]``
|
||||||
|
so FastAPI emits a proper ``oneOf`` + discriminator in the OpenAPI schema.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data: list[DataT]
|
data: list[DataT]
|
||||||
pagination: Pagination
|
pagination: OffsetPagination | CursorPagination
|
||||||
|
pagination_type: PaginationType | None = None
|
||||||
|
filter_attributes: dict[str, list[Any]] | None = None
|
||||||
|
search_columns: list[str] | None = None
|
||||||
|
|
||||||
|
_discriminated_union_cache: ClassVar[dict[Any, Any]] = {}
|
||||||
|
|
||||||
|
def __class_getitem__( # ty:ignore[invalid-method-override]
|
||||||
|
cls, item: type[Any] | tuple[type[Any], ...]
|
||||||
|
) -> type[Any]:
|
||||||
|
if cls is PaginatedResponse and not isinstance(item, TypeVar):
|
||||||
|
cached = cls._discriminated_union_cache.get(item)
|
||||||
|
if cached is None:
|
||||||
|
cached = Annotated[
|
||||||
|
Union[CursorPaginatedResponse[item], OffsetPaginatedResponse[item]], # ty:ignore[invalid-type-form]
|
||||||
|
Field(discriminator="pagination_type"),
|
||||||
|
]
|
||||||
|
cls._discriminated_union_cache[item] = cached
|
||||||
|
return cached # ty:ignore[invalid-return-type]
|
||||||
|
return super().__class_getitem__(item)
|
||||||
|
|
||||||
|
|
||||||
|
class OffsetPaginatedResponse(PaginatedResponse[DataT]):
|
||||||
|
"""Paginated response with typed offset-based pagination metadata.
|
||||||
|
|
||||||
|
The ``pagination_type`` field is always ``"offset"`` and acts as a
|
||||||
|
discriminator, allowing frontend clients to narrow the union type returned
|
||||||
|
by a unified ``paginate()`` endpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pagination: OffsetPagination
|
||||||
|
pagination_type: Literal[PaginationType.OFFSET] = PaginationType.OFFSET
|
||||||
|
|
||||||
|
|
||||||
|
class CursorPaginatedResponse(PaginatedResponse[DataT]):
|
||||||
|
"""Paginated response with typed cursor-based pagination metadata.
|
||||||
|
|
||||||
|
The ``pagination_type`` field is always ``"cursor"`` and acts as a
|
||||||
|
discriminator, allowing frontend clients to narrow the union type returned
|
||||||
|
by a unified ``paginate()`` endpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pagination: CursorPagination
|
||||||
|
pagination_type: Literal[PaginationType.CURSOR] = PaginationType.CURSOR
|
||||||
|
|||||||
27
src/fastapi_toolsets/types.py
Normal file
27
src/fastapi_toolsets/types.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
"""Shared type aliases for the fastapi-toolsets package."""
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator, Callable, Mapping
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute
|
||||||
|
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||||
|
from sqlalchemy.sql.elements import ColumnElement
|
||||||
|
|
||||||
|
# Generic TypeVars
|
||||||
|
DataT = TypeVar("DataT")
|
||||||
|
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||||
|
SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
||||||
|
|
||||||
|
# CRUD type aliases
|
||||||
|
JoinType = list[tuple[type[DeclarativeBase], Any]]
|
||||||
|
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
|
||||||
|
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]
|
||||||
|
|
||||||
|
# Search / facet type aliases
|
||||||
|
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
|
||||||
|
FacetFieldType = SearchFieldType
|
||||||
|
|
||||||
|
# Dependency type aliases
|
||||||
|
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]] | Any
|
||||||
@@ -1,27 +1,38 @@
|
|||||||
"""Shared pytest fixtures for fastapi-utils tests."""
|
"""Shared pytest fixtures for fastapi-utils tests."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import ForeignKey, String
|
import datetime
|
||||||
|
import decimal
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
Column,
|
||||||
|
Date,
|
||||||
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
|
Integer,
|
||||||
|
JSON,
|
||||||
|
Numeric,
|
||||||
|
String,
|
||||||
|
Table,
|
||||||
|
Uuid,
|
||||||
|
)
|
||||||
|
from sqlalchemy.dialects.postgresql import ARRAY
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from fastapi_toolsets.crud import CrudFactory
|
from fastapi_toolsets.crud import CrudFactory
|
||||||
|
from fastapi_toolsets.schemas import PydanticBase
|
||||||
|
|
||||||
# PostgreSQL connection URL from environment or default for local development
|
DATABASE_URL = os.getenv(
|
||||||
DATABASE_URL = os.getenv("DATABASE_URL") or os.getenv(
|
key="DATABASE_URL",
|
||||||
"TEST_DATABASE_URL",
|
default="postgresql+asyncpg://postgres:postgres@localhost:5432/postgres",
|
||||||
"postgresql+asyncpg://postgres:postgres@localhost:5432/fastapi_toolsets_test",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Test Models
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
class Base(DeclarativeBase):
|
||||||
"""Base class for test models."""
|
"""Base class for test models."""
|
||||||
|
|
||||||
@@ -33,7 +44,7 @@ class Role(Base):
|
|||||||
|
|
||||||
__tablename__ = "roles"
|
__tablename__ = "roles"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
name: Mapped[str] = mapped_column(String(50), unique=True)
|
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
|
|
||||||
users: Mapped[list["User"]] = relationship(back_populates="role")
|
users: Mapped[list["User"]] = relationship(back_populates="role")
|
||||||
@@ -44,36 +55,112 @@ class User(Base):
|
|||||||
|
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
username: Mapped[str] = mapped_column(String(50), unique=True)
|
username: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
email: Mapped[str] = mapped_column(String(100), unique=True)
|
email: Mapped[str] = mapped_column(String(100), unique=True)
|
||||||
is_active: Mapped[bool] = mapped_column(default=True)
|
is_active: Mapped[bool] = mapped_column(default=True)
|
||||||
role_id: Mapped[int | None] = mapped_column(ForeignKey("roles.id"), nullable=True)
|
notes: Mapped[str | None]
|
||||||
|
role_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
|
ForeignKey("roles.id"), nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
role: Mapped[Role | None] = relationship(back_populates="users")
|
role: Mapped[Role | None] = relationship(back_populates="users")
|
||||||
|
|
||||||
|
|
||||||
|
class Tag(Base):
|
||||||
|
"""Test tag model."""
|
||||||
|
|
||||||
|
__tablename__ = "tags"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
|
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
|
|
||||||
|
|
||||||
|
post_tags = Table(
|
||||||
|
"post_tags",
|
||||||
|
Base.metadata,
|
||||||
|
Column(
|
||||||
|
"post_id", Uuid, ForeignKey("posts.id", ondelete="CASCADE"), primary_key=True
|
||||||
|
),
|
||||||
|
Column("tag_id", Uuid, ForeignKey("tags.id", ondelete="CASCADE"), primary_key=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class IntRole(Base):
|
||||||
|
"""Test role model with auto-increment integer PK."""
|
||||||
|
|
||||||
|
__tablename__ = "int_roles"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
|
|
||||||
|
|
||||||
|
class Permission(Base):
|
||||||
|
"""Test model with composite primary key."""
|
||||||
|
|
||||||
|
__tablename__ = "permissions"
|
||||||
|
|
||||||
|
subject: Mapped[str] = mapped_column(String(50), primary_key=True)
|
||||||
|
action: Mapped[str] = mapped_column(String(50), primary_key=True)
|
||||||
|
|
||||||
|
|
||||||
|
class Event(Base):
|
||||||
|
"""Test model with DateTime and Date cursor columns."""
|
||||||
|
|
||||||
|
__tablename__ = "events"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
|
name: Mapped[str] = mapped_column(String(100))
|
||||||
|
occurred_at: Mapped[datetime.datetime] = mapped_column(DateTime)
|
||||||
|
scheduled_date: Mapped[datetime.date] = mapped_column(Date)
|
||||||
|
|
||||||
|
|
||||||
|
class Product(Base):
|
||||||
|
"""Test model with Numeric cursor column."""
|
||||||
|
|
||||||
|
__tablename__ = "products"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
|
name: Mapped[str] = mapped_column(String(100))
|
||||||
|
price: Mapped[decimal.Decimal] = mapped_column(Numeric(10, 2))
|
||||||
|
|
||||||
|
|
||||||
class Post(Base):
|
class Post(Base):
|
||||||
"""Test post model."""
|
"""Test post model."""
|
||||||
|
|
||||||
__tablename__ = "posts"
|
__tablename__ = "posts"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
title: Mapped[str] = mapped_column(String(200))
|
title: Mapped[str] = mapped_column(String(200))
|
||||||
content: Mapped[str] = mapped_column(String(1000), default="")
|
content: Mapped[str] = mapped_column(String(1000), default="")
|
||||||
is_published: Mapped[bool] = mapped_column(default=False)
|
is_published: Mapped[bool] = mapped_column(default=False)
|
||||||
author_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
|
author_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
|
||||||
|
|
||||||
|
tags: Mapped[list[Tag]] = relationship(secondary=post_tags)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
class Article(Base):
|
||||||
# Test Schemas
|
"""Test article model with ARRAY and JSON columns."""
|
||||||
# =============================================================================
|
|
||||||
|
__tablename__ = "articles"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
|
title: Mapped[str] = mapped_column(String(200))
|
||||||
|
labels: Mapped[list[str]] = mapped_column(ARRAY(String))
|
||||||
|
metadata_: Mapped[dict | None] = mapped_column("metadata", JSON, nullable=True)
|
||||||
|
|
||||||
|
|
||||||
class RoleCreate(BaseModel):
|
class RoleCreate(BaseModel):
|
||||||
"""Schema for creating a role."""
|
"""Schema for creating a role."""
|
||||||
|
|
||||||
id: int | None = None
|
id: uuid.UUID | None = None
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class RoleRead(PydanticBase):
|
||||||
|
"""Schema for reading a role."""
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
@@ -86,11 +173,19 @@ class RoleUpdate(BaseModel):
|
|||||||
class UserCreate(BaseModel):
|
class UserCreate(BaseModel):
|
||||||
"""Schema for creating a user."""
|
"""Schema for creating a user."""
|
||||||
|
|
||||||
id: int | None = None
|
id: uuid.UUID | None = None
|
||||||
username: str
|
username: str
|
||||||
email: str
|
email: str
|
||||||
is_active: bool = True
|
is_active: bool = True
|
||||||
role_id: int | None = None
|
role_id: uuid.UUID | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class UserRead(PydanticBase):
|
||||||
|
"""Schema for reading a user (subset of fields — no email)."""
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
|
username: str
|
||||||
|
is_active: bool = True
|
||||||
|
|
||||||
|
|
||||||
class UserUpdate(BaseModel):
|
class UserUpdate(BaseModel):
|
||||||
@@ -99,17 +194,24 @@ class UserUpdate(BaseModel):
|
|||||||
username: str | None = None
|
username: str | None = None
|
||||||
email: str | None = None
|
email: str | None = None
|
||||||
is_active: bool | None = None
|
is_active: bool | None = None
|
||||||
role_id: int | None = None
|
role_id: uuid.UUID | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TagCreate(BaseModel):
|
||||||
|
"""Schema for creating a tag."""
|
||||||
|
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
class PostCreate(BaseModel):
|
class PostCreate(BaseModel):
|
||||||
"""Schema for creating a post."""
|
"""Schema for creating a post."""
|
||||||
|
|
||||||
id: int | None = None
|
id: uuid.UUID | None = None
|
||||||
title: str
|
title: str
|
||||||
content: str = ""
|
content: str = ""
|
||||||
is_published: bool = False
|
is_published: bool = False
|
||||||
author_id: int
|
author_id: uuid.UUID
|
||||||
|
|
||||||
|
|
||||||
class PostUpdate(BaseModel):
|
class PostUpdate(BaseModel):
|
||||||
@@ -120,18 +222,98 @@ class PostUpdate(BaseModel):
|
|||||||
is_published: bool | None = None
|
is_published: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
class PostM2MCreate(BaseModel):
|
||||||
# CRUD Classes
|
"""Schema for creating a post with M2M tag IDs."""
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
title: str
|
||||||
|
content: str = ""
|
||||||
|
is_published: bool = False
|
||||||
|
author_id: uuid.UUID
|
||||||
|
tag_ids: list[uuid.UUID] = []
|
||||||
|
|
||||||
|
|
||||||
|
class PostM2MUpdate(BaseModel):
|
||||||
|
"""Schema for updating a post with M2M tag IDs."""
|
||||||
|
|
||||||
|
title: str | None = None
|
||||||
|
content: str | None = None
|
||||||
|
is_published: bool | None = None
|
||||||
|
tag_ids: list[uuid.UUID] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class IntRoleRead(PydanticBase):
|
||||||
|
"""Schema for reading an IntRole."""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class IntRoleCreate(BaseModel):
|
||||||
|
"""Schema for creating an IntRole."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class EventRead(PydanticBase):
|
||||||
|
"""Schema for reading an Event."""
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class EventCreate(BaseModel):
|
||||||
|
"""Schema for creating an Event."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
occurred_at: datetime.datetime
|
||||||
|
scheduled_date: datetime.date
|
||||||
|
|
||||||
|
|
||||||
|
class ProductRead(PydanticBase):
|
||||||
|
"""Schema for reading a Product."""
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class ProductCreate(BaseModel):
|
||||||
|
"""Schema for creating a Product."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
price: decimal.Decimal
|
||||||
|
|
||||||
|
|
||||||
|
class ArticleCreate(BaseModel):
|
||||||
|
"""Schema for creating an article."""
|
||||||
|
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
title: str
|
||||||
|
labels: list[str] = []
|
||||||
|
|
||||||
|
|
||||||
|
class ArticleRead(PydanticBase):
|
||||||
|
"""Schema for reading an article."""
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
|
title: str
|
||||||
|
labels: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
ArticleCrud = CrudFactory(Article)
|
||||||
RoleCrud = CrudFactory(Role)
|
RoleCrud = CrudFactory(Role)
|
||||||
|
RoleCursorCrud = CrudFactory(Role, cursor_column=Role.id)
|
||||||
|
IntRoleCursorCrud = CrudFactory(IntRole, cursor_column=IntRole.id)
|
||||||
UserCrud = CrudFactory(User)
|
UserCrud = CrudFactory(User)
|
||||||
|
UserCursorCrud = CrudFactory(User, cursor_column=User.id)
|
||||||
PostCrud = CrudFactory(Post)
|
PostCrud = CrudFactory(Post)
|
||||||
|
TagCrud = CrudFactory(Tag)
|
||||||
|
PostM2MCrud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
|
||||||
# =============================================================================
|
EventCrud = CrudFactory(Event)
|
||||||
# Fixtures
|
EventDateTimeCursorCrud = CrudFactory(Event, cursor_column=Event.occurred_at)
|
||||||
# =============================================================================
|
EventDateCursorCrud = CrudFactory(Event, cursor_column=Event.scheduled_date)
|
||||||
|
ProductCrud = CrudFactory(Product)
|
||||||
|
ProductNumericCursorCrud = CrudFactory(Product, cursor_column=Product.price)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -170,30 +352,3 @@ async def db_session(engine):
|
|||||||
# Drop tables after test
|
# Drop tables after test
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(Base.metadata.drop_all)
|
await conn.run_sync(Base.metadata.drop_all)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_role_data() -> RoleCreate:
|
|
||||||
"""Sample role creation data."""
|
|
||||||
return RoleCreate(name="admin")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_user_data() -> UserCreate:
|
|
||||||
"""Sample user creation data."""
|
|
||||||
return UserCreate(
|
|
||||||
username="testuser",
|
|
||||||
email="test@example.com",
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_post_data() -> PostCreate:
|
|
||||||
"""Sample post creation data."""
|
|
||||||
return PostCreate(
|
|
||||||
title="Test Post",
|
|
||||||
content="Test content",
|
|
||||||
is_published=True,
|
|
||||||
author_id=1,
|
|
||||||
)
|
|
||||||
|
|||||||
543
tests/test_cli.py
Normal file
543
tests/test_cli.py
Normal file
@@ -0,0 +1,543 @@
|
|||||||
|
"""Tests for fastapi_toolsets.cli module."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli.config import (
|
||||||
|
get_config_value,
|
||||||
|
get_custom_cli,
|
||||||
|
get_db_context,
|
||||||
|
get_fixtures_registry,
|
||||||
|
import_from_string,
|
||||||
|
)
|
||||||
|
from fastapi_toolsets.cli.pyproject import find_pyproject, load_pyproject
|
||||||
|
from fastapi_toolsets.cli.utils import async_command
|
||||||
|
from fastapi_toolsets.fixtures import FixtureRegistry
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
|
||||||
|
|
||||||
|
class TestPyproject:
|
||||||
|
"""Tests for pyproject.toml discovery and loading."""
|
||||||
|
|
||||||
|
def test_find_pyproject_in_current_dir(self, tmp_path, monkeypatch):
|
||||||
|
"""Finds pyproject.toml in current directory."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[project]\nname = 'test'\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = find_pyproject()
|
||||||
|
assert result == pyproject
|
||||||
|
|
||||||
|
def test_find_pyproject_in_parent_dir(self, tmp_path, monkeypatch):
|
||||||
|
"""Finds pyproject.toml in parent directory."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[project]\nname = 'test'\n")
|
||||||
|
subdir = tmp_path / "src" / "app"
|
||||||
|
subdir.mkdir(parents=True)
|
||||||
|
monkeypatch.chdir(subdir)
|
||||||
|
|
||||||
|
result = find_pyproject()
|
||||||
|
assert result == pyproject
|
||||||
|
|
||||||
|
def test_find_pyproject_not_found(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns None when no pyproject.toml exists."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
result = find_pyproject()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_load_pyproject_returns_tool_config(self, tmp_path, monkeypatch):
|
||||||
|
"""load_pyproject returns the [tool.fastapi-toolsets] section."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
'[tool.fastapi-toolsets]\nfixtures = "app.fixtures:registry"\n'
|
||||||
|
)
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = load_pyproject()
|
||||||
|
assert result == {"fixtures": "app.fixtures:registry"}
|
||||||
|
|
||||||
|
def test_load_pyproject_empty_when_no_file(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns empty dict when no pyproject.toml exists."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
result = load_pyproject()
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_load_pyproject_empty_when_no_tool_section(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns empty dict when no [tool.fastapi-toolsets] section."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[project]\nname = 'test'\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = load_pyproject()
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_load_pyproject_invalid_toml(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns empty dict when pyproject.toml is invalid."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("invalid toml {{{")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = load_pyproject()
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestImportFromString:
|
||||||
|
"""Tests for import_from_string function."""
|
||||||
|
|
||||||
|
def test_import_valid_path(self):
|
||||||
|
"""Import valid module:attribute path."""
|
||||||
|
result = import_from_string("fastapi_toolsets.fixtures:FixtureRegistry")
|
||||||
|
assert result is FixtureRegistry
|
||||||
|
|
||||||
|
def test_import_without_colon_raises_error(self):
|
||||||
|
"""Import path without colon raises error."""
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
import_from_string("fastapi_toolsets.fixtures.FixtureRegistry")
|
||||||
|
assert "Expected format: 'module:attribute'" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_import_nonexistent_module_raises_error(self):
|
||||||
|
"""Import nonexistent module raises error."""
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
import_from_string("nonexistent.module:something")
|
||||||
|
assert "Cannot import module" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_import_nonexistent_attribute_raises_error(self):
|
||||||
|
"""Import nonexistent attribute raises error."""
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
import_from_string("fastapi_toolsets.fixtures:NonexistentClass")
|
||||||
|
assert "has no attribute" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetConfigValue:
|
||||||
|
"""Tests for get_config_value function."""
|
||||||
|
|
||||||
|
def test_get_existing_value(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns value when key exists."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text('[tool.fastapi-toolsets]\nfixtures = "app:registry"\n')
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = get_config_value("fixtures")
|
||||||
|
assert result == "app:registry"
|
||||||
|
|
||||||
|
def test_get_missing_value_returns_none(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns None when key is missing and not required."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = get_config_value("fixtures")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_get_missing_value_required_raises_error(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when key is missing and required."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_config_value("fixtures", required=True)
|
||||||
|
assert "No 'fixtures' configured" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetFixturesRegistry:
|
||||||
|
"""Tests for get_fixtures_registry function."""
|
||||||
|
|
||||||
|
def test_raises_when_not_configured(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when fixtures not configured."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_fixtures_registry()
|
||||||
|
assert "No 'fixtures' configured" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_raises_when_not_registry_instance(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when imported object is not a FixtureRegistry."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
'[tool.fastapi-toolsets]\nfixtures = "my_fixtures:registry"\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
fixtures_file = tmp_path / "my_fixtures.py"
|
||||||
|
fixtures_file.write_text("registry = 'not a registry'\n")
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
try:
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_fixtures_registry()
|
||||||
|
assert "must be a FixtureRegistry instance" in str(exc_info.value)
|
||||||
|
finally:
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
if "my_fixtures" in sys.modules:
|
||||||
|
del sys.modules["my_fixtures"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetDbContext:
|
||||||
|
"""Tests for get_db_context function."""
|
||||||
|
|
||||||
|
def test_raises_when_not_configured(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when db_context not configured."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_db_context()
|
||||||
|
assert "No 'db_context' configured" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCustomCli:
|
||||||
|
"""Tests for get_custom_cli function."""
|
||||||
|
|
||||||
|
def test_returns_none_when_not_configured(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns None when custom_cli not configured."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = get_custom_cli()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_raises_when_not_typer_instance(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when imported object is not a Typer instance."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text('[tool.fastapi-toolsets]\ncustom_cli = "my_cli:cli"\n')
|
||||||
|
|
||||||
|
cli_file = tmp_path / "my_cli.py"
|
||||||
|
cli_file.write_text("cli = 'not a typer'\n")
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
try:
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_custom_cli()
|
||||||
|
assert "must be a Typer instance" in str(exc_info.value)
|
||||||
|
finally:
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
if "my_cli" in sys.modules:
|
||||||
|
del sys.modules["my_cli"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestCliApp:
|
||||||
|
"""Tests for CLI application."""
|
||||||
|
|
||||||
|
def test_cli_help(self, tmp_path, monkeypatch):
|
||||||
|
"""CLI shows help without fixtures."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Need to reload the module to pick up new cwd
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
result = runner.invoke(app.cli, ["--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "CLI utilities for FastAPI projects" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
class TestFixturesCli:
|
||||||
|
"""Tests for fixtures CLI commands."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def cli_env(self, tmp_path, monkeypatch):
|
||||||
|
"""Set up CLI environment with fixtures config."""
|
||||||
|
# Create pyproject.toml
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
"[tool.fastapi-toolsets]\n"
|
||||||
|
'fixtures = "fixtures:registry"\n'
|
||||||
|
'db_context = "db:get_session"\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create fixtures module
|
||||||
|
fixtures_file = tmp_path / "fixtures.py"
|
||||||
|
fixtures_file.write_text(
|
||||||
|
"from fastapi_toolsets.fixtures import FixtureRegistry, Context\n"
|
||||||
|
"\n"
|
||||||
|
"registry = FixtureRegistry()\n"
|
||||||
|
"\n"
|
||||||
|
"@registry.register(contexts=[Context.BASE])\n"
|
||||||
|
"def roles():\n"
|
||||||
|
' return [{"id": 1, "name": "admin"}, {"id": 2, "name": "user"}]\n'
|
||||||
|
"\n"
|
||||||
|
'@registry.register(depends_on=["roles"], contexts=[Context.TESTING])\n'
|
||||||
|
"def users():\n"
|
||||||
|
' return [{"id": 1, "name": "alice", "role_id": 1}]\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create db module
|
||||||
|
db_file = tmp_path / "db.py"
|
||||||
|
db_file.write_text(
|
||||||
|
"from contextlib import asynccontextmanager\n"
|
||||||
|
"\n"
|
||||||
|
"@asynccontextmanager\n"
|
||||||
|
"async def get_session():\n"
|
||||||
|
" yield None\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Add tmp_path to sys.path for imports
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
# Reload the CLI module to pick up new config
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
yield tmp_path, app.cli
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
|
||||||
|
def test_fixtures_list(self, cli_env):
|
||||||
|
"""fixtures list shows registered fixtures."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(cli, ["fixtures", "list"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "roles" in result.output
|
||||||
|
assert "users" in result.output
|
||||||
|
assert "Total: 2 fixture(s)" in result.output
|
||||||
|
|
||||||
|
def test_fixtures_list_with_context(self, cli_env):
|
||||||
|
"""fixtures list --context filters by context."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(cli, ["fixtures", "list", "--context", "base"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "roles" in result.output
|
||||||
|
assert "users" not in result.output
|
||||||
|
assert "Total: 1 fixture(s)" in result.output
|
||||||
|
|
||||||
|
def test_fixtures_load_dry_run(self, cli_env):
|
||||||
|
"""fixtures load --dry-run shows what would be loaded."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(cli, ["fixtures", "load", "base", "--dry-run"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Fixtures to load" in result.output
|
||||||
|
assert "roles" in result.output
|
||||||
|
assert "[Dry run - no changes made]" in result.output
|
||||||
|
|
||||||
|
def test_fixtures_load_invalid_strategy(self, cli_env):
|
||||||
|
"""fixtures load with invalid strategy shows error."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(
|
||||||
|
cli, ["fixtures", "load", "base", "--strategy", "invalid"]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code != 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestCliWithoutFixturesConfig:
|
||||||
|
"""Tests for CLI when fixtures is not configured."""
|
||||||
|
|
||||||
|
def test_no_fixtures_command(self, tmp_path, monkeypatch):
|
||||||
|
"""fixtures command is not available when not configured."""
|
||||||
|
# Create pyproject.toml without fixtures
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text('[project]\nname = "test"\n')
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Reload the CLI module
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
result = runner.invoke(app.cli, ["--help"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "fixtures" not in result.output
|
||||||
|
|
||||||
|
|
||||||
|
class TestCustomCliConfig:
|
||||||
|
"""Tests for custom CLI configuration."""
|
||||||
|
|
||||||
|
def test_cli_with_custom_cli(self, tmp_path, monkeypatch):
|
||||||
|
"""CLI uses custom Typer instance when configured."""
|
||||||
|
import typer
|
||||||
|
|
||||||
|
# Create pyproject.toml with custom_cli config
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text('[tool.fastapi-toolsets]\ncustom_cli = "my_cli:cli"\n')
|
||||||
|
|
||||||
|
# Create custom CLI module with its own Typer and commands
|
||||||
|
cli_file = tmp_path / "my_cli.py"
|
||||||
|
cli_file.write_text(
|
||||||
|
"import typer\n"
|
||||||
|
"\n"
|
||||||
|
"cli = typer.Typer(name='my-app', help='My custom CLI')\n"
|
||||||
|
"\n"
|
||||||
|
"@cli.command()\n"
|
||||||
|
"def hello():\n"
|
||||||
|
' print("Hello from custom CLI!")\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Add tmp_path to sys.path for imports
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
# Remove my_cli from sys.modules if it was previously loaded
|
||||||
|
if "my_cli" in sys.modules:
|
||||||
|
del sys.modules["my_cli"]
|
||||||
|
|
||||||
|
# Reload the CLI module to pick up new config
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Verify custom CLI is used
|
||||||
|
assert isinstance(app.cli, typer.Typer)
|
||||||
|
|
||||||
|
result = runner.invoke(app.cli, ["--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "My custom CLI" in result.output
|
||||||
|
assert "hello" in result.output
|
||||||
|
|
||||||
|
result = runner.invoke(app.cli, ["hello"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Hello from custom CLI!" in result.output
|
||||||
|
finally:
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
if "my_cli" in sys.modules:
|
||||||
|
del sys.modules["my_cli"]
|
||||||
|
|
||||||
|
def test_custom_cli_with_fixtures(self, tmp_path, monkeypatch):
|
||||||
|
"""Custom CLI gets fixtures command added when configured."""
|
||||||
|
# Create pyproject.toml with both custom_cli and fixtures
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
"[tool.fastapi-toolsets]\n"
|
||||||
|
'custom_cli = "my_cli:cli"\n'
|
||||||
|
'fixtures = "fixtures:registry"\n'
|
||||||
|
'db_context = "db:get_session"\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create custom CLI module
|
||||||
|
cli_file = tmp_path / "my_cli.py"
|
||||||
|
cli_file.write_text(
|
||||||
|
"import typer\n"
|
||||||
|
"\n"
|
||||||
|
"cli = typer.Typer(name='my-app', help='My custom CLI')\n"
|
||||||
|
"\n"
|
||||||
|
"@cli.command()\n"
|
||||||
|
"def hello():\n"
|
||||||
|
' print("Hello!")\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create fixtures module
|
||||||
|
fixtures_file = tmp_path / "fixtures.py"
|
||||||
|
fixtures_file.write_text(
|
||||||
|
"from fastapi_toolsets.fixtures import FixtureRegistry\n"
|
||||||
|
"\n"
|
||||||
|
"registry = FixtureRegistry()\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create db module
|
||||||
|
db_file = tmp_path / "db.py"
|
||||||
|
db_file.write_text(
|
||||||
|
"from contextlib import asynccontextmanager\n"
|
||||||
|
"\n"
|
||||||
|
"@asynccontextmanager\n"
|
||||||
|
"async def get_session():\n"
|
||||||
|
" yield None\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
for mod in ["my_cli", "fixtures", "db"]:
|
||||||
|
if mod in sys.modules:
|
||||||
|
del sys.modules[mod]
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = runner.invoke(app.cli, ["--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
# Should have both custom command and fixtures
|
||||||
|
assert "hello" in result.output
|
||||||
|
assert "fixtures" in result.output
|
||||||
|
finally:
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
for mod in ["my_cli", "fixtures", "db"]:
|
||||||
|
if mod in sys.modules:
|
||||||
|
del sys.modules[mod]
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncCommand:
|
||||||
|
"""Tests for async_command decorator."""
|
||||||
|
|
||||||
|
def test_async_command_runs_coroutine(self):
|
||||||
|
"""async_command runs async function synchronously."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def async_func(value: int) -> int:
|
||||||
|
return value * 2
|
||||||
|
|
||||||
|
result = async_func(21)
|
||||||
|
assert result == 42
|
||||||
|
|
||||||
|
def test_async_command_preserves_signature(self):
|
||||||
|
"""async_command preserves function signature."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def async_func(name: str, count: int = 1) -> str:
|
||||||
|
return f"{name} x {count}"
|
||||||
|
|
||||||
|
result = async_func("test", count=3)
|
||||||
|
assert result == "test x 3"
|
||||||
|
|
||||||
|
def test_async_command_preserves_docstring(self):
|
||||||
|
"""async_command preserves function docstring."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def async_func() -> None:
|
||||||
|
"""This is a docstring."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert async_func.__doc__ == """This is a docstring."""
|
||||||
|
|
||||||
|
def test_async_command_preserves_name(self):
|
||||||
|
"""async_command preserves function name."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def my_async_function() -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert my_async_function.__name__ == "my_async_function"
|
||||||
2191
tests/test_crud.py
2191
tests/test_crud.py
File diff suppressed because it is too large
Load Diff
1963
tests/test_crud_search.py
Normal file
1963
tests/test_crud_search.py
Normal file
File diff suppressed because it is too large
Load Diff
241
tests/test_db.py
241
tests/test_db.py
@@ -1,17 +1,28 @@
|
|||||||
"""Tests for fastapi_toolsets.db module."""
|
"""Tests for fastapi_toolsets.db module."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.engine import make_url
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
from fastapi_toolsets.db import (
|
from fastapi_toolsets.db import (
|
||||||
LockMode,
|
LockMode,
|
||||||
|
cleanup_tables,
|
||||||
|
create_database,
|
||||||
create_db_context,
|
create_db_context,
|
||||||
create_db_dependency,
|
create_db_dependency,
|
||||||
get_transaction,
|
get_transaction,
|
||||||
lock_tables,
|
lock_tables,
|
||||||
|
wait_for_row_change,
|
||||||
)
|
)
|
||||||
|
from fastapi_toolsets.exceptions import NotFoundError
|
||||||
|
from fastapi_toolsets.pytest import create_db_session
|
||||||
|
|
||||||
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User
|
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
|
||||||
|
|
||||||
|
|
||||||
class TestCreateDbDependency:
|
class TestCreateDbDependency:
|
||||||
@@ -57,6 +68,55 @@ class TestCreateDbDependency:
|
|||||||
await conn.run_sync(Base.metadata.drop_all)
|
await conn.run_sync(Base.metadata.drop_all)
|
||||||
await engine.dispose()
|
await engine.dispose()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_in_transaction_on_yield(self):
|
||||||
|
"""Session is already in a transaction when the endpoint body starts."""
|
||||||
|
engine = create_async_engine(DATABASE_URL, echo=False)
|
||||||
|
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
get_db = create_db_dependency(session_factory)
|
||||||
|
|
||||||
|
async for session in get_db():
|
||||||
|
assert session.in_transaction()
|
||||||
|
break
|
||||||
|
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_update_after_lock_tables_is_persisted(self):
|
||||||
|
"""Changes made after lock_tables exits (before endpoint returns) are committed.
|
||||||
|
|
||||||
|
Regression: without the auto-begin fix, lock_tables would start and commit a
|
||||||
|
real outer transaction, leaving the session idle. Any modifications after that
|
||||||
|
point were silently dropped.
|
||||||
|
"""
|
||||||
|
engine = create_async_engine(DATABASE_URL, echo=False)
|
||||||
|
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
try:
|
||||||
|
get_db = create_db_dependency(session_factory)
|
||||||
|
|
||||||
|
async for session in get_db():
|
||||||
|
async with lock_tables(session, [Role]):
|
||||||
|
role = Role(name="lock_then_update")
|
||||||
|
session.add(role)
|
||||||
|
await session.flush()
|
||||||
|
# lock_tables has exited — outer transaction must still be open
|
||||||
|
assert session.in_transaction()
|
||||||
|
role.name = "updated_after_lock"
|
||||||
|
|
||||||
|
async with session_factory() as verify:
|
||||||
|
result = await RoleCrud.first(
|
||||||
|
verify, [Role.name == "updated_after_lock"]
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
finally:
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.drop_all)
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
class TestCreateDbContext:
|
class TestCreateDbContext:
|
||||||
"""Tests for create_db_context."""
|
"""Tests for create_db_context."""
|
||||||
@@ -241,3 +301,182 @@ class TestLockTables:
|
|||||||
|
|
||||||
result = await RoleCrud.first(db_session, [Role.name == "lock_rollback_role"])
|
result = await RoleCrud.first(db_session, [Role.name == "lock_rollback_role"])
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestWaitForRowChange:
|
||||||
|
"""Tests for wait_for_row_change polling function."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_detects_update(self, db_session: AsyncSession, engine):
|
||||||
|
"""Returns updated instance when a column value changes."""
|
||||||
|
role = Role(name="watch_role")
|
||||||
|
db_session.add(role)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
async def update_later():
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
async with factory() as other:
|
||||||
|
r = await other.get(Role, role.id)
|
||||||
|
assert r is not None
|
||||||
|
r.name = "updated_role"
|
||||||
|
await other.commit()
|
||||||
|
|
||||||
|
update_task = asyncio.create_task(update_later())
|
||||||
|
result = await wait_for_row_change(db_session, Role, role.id, interval=0.05)
|
||||||
|
await update_task
|
||||||
|
|
||||||
|
assert result.name == "updated_role"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_watches_specific_columns(self, db_session: AsyncSession, engine):
|
||||||
|
"""Only triggers on changes to specified columns."""
|
||||||
|
user = User(username="testuser", email="test@example.com")
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
async def update_later():
|
||||||
|
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
# First: change email (not watched) — should not trigger
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
async with factory() as other:
|
||||||
|
u = await other.get(User, user.id)
|
||||||
|
assert u is not None
|
||||||
|
u.email = "new@example.com"
|
||||||
|
await other.commit()
|
||||||
|
# Second: change username (watched) — should trigger
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
async with factory() as other:
|
||||||
|
u = await other.get(User, user.id)
|
||||||
|
assert u is not None
|
||||||
|
u.username = "newuser"
|
||||||
|
await other.commit()
|
||||||
|
|
||||||
|
update_task = asyncio.create_task(update_later())
|
||||||
|
result = await wait_for_row_change(
|
||||||
|
db_session, User, user.id, columns=["username"], interval=0.05
|
||||||
|
)
|
||||||
|
await update_task
|
||||||
|
|
||||||
|
assert result.username == "newuser"
|
||||||
|
assert result.email == "new@example.com"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_nonexistent_row_raises(self, db_session: AsyncSession):
|
||||||
|
"""Raises NotFoundError when the row does not exist."""
|
||||||
|
fake_id = uuid.uuid4()
|
||||||
|
with pytest.raises(NotFoundError, match="not found"):
|
||||||
|
await wait_for_row_change(db_session, Role, fake_id, interval=0.05)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_timeout_raises(self, db_session: AsyncSession):
|
||||||
|
"""Raises TimeoutError when no change is detected within timeout."""
|
||||||
|
role = Role(name="timeout_role")
|
||||||
|
db_session.add(role)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
with pytest.raises(TimeoutError):
|
||||||
|
await wait_for_row_change(
|
||||||
|
db_session, Role, role.id, interval=0.05, timeout=0.2
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_deleted_row_raises(self, db_session: AsyncSession, engine):
|
||||||
|
"""Raises NotFoundError when the row is deleted during polling."""
|
||||||
|
role = Role(name="delete_role")
|
||||||
|
db_session.add(role)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
async def delete_later():
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
async with factory() as other:
|
||||||
|
r = await other.get(Role, role.id)
|
||||||
|
await other.delete(r)
|
||||||
|
await other.commit()
|
||||||
|
|
||||||
|
delete_task = asyncio.create_task(delete_later())
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await wait_for_row_change(db_session, Role, role.id, interval=0.05)
|
||||||
|
await delete_task
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateDatabase:
|
||||||
|
"""Tests for create_database."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_creates_database(self):
|
||||||
|
"""Database is created by create_database."""
|
||||||
|
target_url = (
|
||||||
|
make_url(DATABASE_URL)
|
||||||
|
.set(database="test_create_db_general")
|
||||||
|
.render_as_string(hide_password=False)
|
||||||
|
)
|
||||||
|
expected_db = make_url(target_url).database
|
||||||
|
assert expected_db is not None
|
||||||
|
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
try:
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
await conn.execute(text(f"DROP DATABASE IF EXISTS {expected_db}"))
|
||||||
|
|
||||||
|
await create_database(db_name=expected_db, server_url=DATABASE_URL)
|
||||||
|
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||||
|
{"name": expected_db},
|
||||||
|
)
|
||||||
|
assert result.scalar() == 1
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
await conn.execute(text(f"DROP DATABASE IF EXISTS {expected_db}"))
|
||||||
|
finally:
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
class TestCleanupTables:
|
||||||
|
"""Tests for cleanup_tables helper."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_truncates_all_tables(self):
|
||||||
|
"""All table rows are removed after cleanup_tables."""
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||||
|
role = Role(id=uuid.uuid4(), name="cleanup_role")
|
||||||
|
session.add(role)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
username="cleanup_user",
|
||||||
|
email="cleanup@test.com",
|
||||||
|
role_id=role.id,
|
||||||
|
)
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Verify rows exist
|
||||||
|
roles_count = await RoleCrud.count(session)
|
||||||
|
users_count = await UserCrud.count(session)
|
||||||
|
assert roles_count == 1
|
||||||
|
assert users_count == 1
|
||||||
|
|
||||||
|
await cleanup_tables(session, Base)
|
||||||
|
|
||||||
|
# Verify tables are empty
|
||||||
|
roles_count = await RoleCrud.count(session)
|
||||||
|
users_count = await UserCrud.count(session)
|
||||||
|
assert roles_count == 0
|
||||||
|
assert users_count == 0
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_noop_for_empty_metadata(self):
|
||||||
|
"""cleanup_tables does not raise when metadata has no tables."""
|
||||||
|
|
||||||
|
class EmptyBase(DeclarativeBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||||
|
# Should not raise
|
||||||
|
await cleanup_tables(session, EmptyBase)
|
||||||
|
|||||||
277
tests/test_dependencies.py
Normal file
277
tests/test_dependencies.py
Normal file
@@ -0,0 +1,277 @@
|
|||||||
|
"""Tests for fastapi_toolsets.dependencies module."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import uuid
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Annotated, Any, cast
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.params import Depends
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from fastapi_toolsets.dependencies import (
|
||||||
|
BodyDependency,
|
||||||
|
PathDependency,
|
||||||
|
_unwrap_session_dep,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .conftest import Role, RoleCreate, RoleCrud, User
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""Mock session dependency for testing."""
|
||||||
|
yield None # type: ignore[misc] # ty:ignore[invalid-yield]
|
||||||
|
|
||||||
|
|
||||||
|
MockSessionDep = Annotated[AsyncSession, Depends(mock_get_db)]
|
||||||
|
|
||||||
|
|
||||||
|
class TestUnwrapSessionDep:
|
||||||
|
def test_plain_callable_returned_as_is(self):
|
||||||
|
"""Plain callable is returned unchanged."""
|
||||||
|
assert _unwrap_session_dep(mock_get_db) is mock_get_db
|
||||||
|
|
||||||
|
def test_annotated_with_depends_unwrapped(self):
|
||||||
|
"""Annotated form with Depends is unwrapped to the plain callable."""
|
||||||
|
assert _unwrap_session_dep(MockSessionDep) is mock_get_db
|
||||||
|
|
||||||
|
def test_annotated_without_depends_returned_as_is(self):
|
||||||
|
"""Annotated form with no Depends falls back to returning session_dep as-is."""
|
||||||
|
annotated_no_dep = Annotated[AsyncSession, "not_a_depends"]
|
||||||
|
assert _unwrap_session_dep(annotated_no_dep) is annotated_no_dep
|
||||||
|
|
||||||
|
|
||||||
|
class TestPathDependency:
|
||||||
|
"""Tests for PathDependency factory."""
|
||||||
|
|
||||||
|
def test_returns_depends_instance(self):
|
||||||
|
"""PathDependency returns a Depends instance."""
|
||||||
|
dep = PathDependency(Role, Role.id, session_dep=mock_get_db)
|
||||||
|
assert isinstance(dep, Depends)
|
||||||
|
|
||||||
|
def test_signature_has_default_param_name(self):
|
||||||
|
"""PathDependency uses model_field as default param name."""
|
||||||
|
dep = cast(Any, PathDependency(Role, Role.id, session_dep=mock_get_db))
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
params = list(sig.parameters.keys())
|
||||||
|
|
||||||
|
assert "role_id" in params
|
||||||
|
assert "session" in params
|
||||||
|
|
||||||
|
def test_signature_has_correct_type_annotation(self):
|
||||||
|
"""PathDependency uses field's python type for annotation."""
|
||||||
|
dep = cast(Any, PathDependency(Role, Role.id, session_dep=mock_get_db))
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
assert sig.parameters["role_id"].annotation == uuid.UUID
|
||||||
|
assert sig.parameters["session"].annotation == AsyncSession
|
||||||
|
|
||||||
|
def test_signature_session_has_depends_default(self):
|
||||||
|
"""PathDependency session param has Depends as default."""
|
||||||
|
dep = cast(Any, PathDependency(Role, Role.id, session_dep=mock_get_db))
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
assert isinstance(sig.parameters["session"].default, Depends)
|
||||||
|
|
||||||
|
def test_custom_param_name_in_signature(self):
|
||||||
|
"""PathDependency uses custom param_name in signature."""
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
PathDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, param_name="role_uuid"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
params = list(sig.parameters.keys())
|
||||||
|
|
||||||
|
assert "role_uuid" in params
|
||||||
|
assert "id" not in params
|
||||||
|
|
||||||
|
def test_string_field_type(self):
|
||||||
|
"""PathDependency handles string field types."""
|
||||||
|
dep = cast(Any, PathDependency(User, User.username, session_dep=mock_get_db))
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
assert sig.parameters["user_username"].annotation is str
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_dependency_fetches_object(self, db_session):
|
||||||
|
"""PathDependency inner function fetches object from database."""
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="test_role"))
|
||||||
|
|
||||||
|
dep = cast(Any, PathDependency(Role, Role.id, session_dep=mock_get_db))
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
result = await func(session=db_session, role_id=role.id)
|
||||||
|
|
||||||
|
assert result.id == role.id
|
||||||
|
assert result.name == "test_role"
|
||||||
|
|
||||||
|
def test_annotated_session_dep_returns_depends_instance(self):
|
||||||
|
"""PathDependency accepts Annotated[AsyncSession, Depends(...)] form."""
|
||||||
|
dep = PathDependency(Role, Role.id, session_dep=MockSessionDep)
|
||||||
|
assert isinstance(dep, Depends)
|
||||||
|
|
||||||
|
def test_annotated_session_dep_signature(self):
|
||||||
|
"""PathDependency with Annotated session_dep produces a valid signature."""
|
||||||
|
dep = cast(Any, PathDependency(Role, Role.id, session_dep=MockSessionDep))
|
||||||
|
sig = inspect.signature(dep.dependency)
|
||||||
|
|
||||||
|
assert "role_id" in sig.parameters
|
||||||
|
assert "session" in sig.parameters
|
||||||
|
assert isinstance(sig.parameters["session"].default, Depends)
|
||||||
|
|
||||||
|
def test_annotated_session_dep_unwraps_callable(self):
|
||||||
|
"""PathDependency with Annotated form uses the underlying callable, not the Annotated type."""
|
||||||
|
dep = cast(Any, PathDependency(Role, Role.id, session_dep=MockSessionDep))
|
||||||
|
sig = inspect.signature(dep.dependency)
|
||||||
|
|
||||||
|
inner_dep = sig.parameters["session"].default
|
||||||
|
assert inner_dep.dependency is mock_get_db
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_annotated_session_dep_fetches_object(self, db_session):
|
||||||
|
"""PathDependency with Annotated session_dep correctly fetches object from database."""
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="annotated_role"))
|
||||||
|
|
||||||
|
dep = cast(Any, PathDependency(Role, Role.id, session_dep=MockSessionDep))
|
||||||
|
result = await dep.dependency(session=db_session, role_id=role.id)
|
||||||
|
|
||||||
|
assert result.id == role.id
|
||||||
|
assert result.name == "annotated_role"
|
||||||
|
|
||||||
|
|
||||||
|
class TestBodyDependency:
|
||||||
|
"""Tests for BodyDependency factory."""
|
||||||
|
|
||||||
|
def test_returns_depends_instance(self):
|
||||||
|
"""BodyDependency returns a Depends instance."""
|
||||||
|
dep = BodyDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||||
|
)
|
||||||
|
assert isinstance(dep, Depends)
|
||||||
|
|
||||||
|
def test_signature_has_body_field_as_param(self):
|
||||||
|
"""BodyDependency uses body_field as param name."""
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
params = list(sig.parameters.keys())
|
||||||
|
|
||||||
|
assert "role_id" in params
|
||||||
|
assert "session" in params
|
||||||
|
|
||||||
|
def test_signature_has_correct_type_annotation(self):
|
||||||
|
"""BodyDependency uses field's python type for annotation."""
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
assert sig.parameters["role_id"].annotation == uuid.UUID
|
||||||
|
assert sig.parameters["session"].annotation == AsyncSession
|
||||||
|
|
||||||
|
def test_signature_session_has_depends_default(self):
|
||||||
|
"""BodyDependency session param has Depends as default."""
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
assert isinstance(sig.parameters["session"].default, Depends)
|
||||||
|
|
||||||
|
def test_different_body_field_name(self):
|
||||||
|
"""BodyDependency can use any body_field name."""
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
User, User.id, session_dep=mock_get_db, body_field="user_uuid"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
params = list(sig.parameters.keys())
|
||||||
|
|
||||||
|
assert "user_uuid" in params
|
||||||
|
assert "id" not in params
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_dependency_fetches_object(self, db_session):
|
||||||
|
"""BodyDependency inner function fetches object from database."""
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="body_test_role"))
|
||||||
|
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
result = await func(session=db_session, role_id=role.id)
|
||||||
|
|
||||||
|
assert result.id == role.id
|
||||||
|
assert result.name == "body_test_role"
|
||||||
|
|
||||||
|
def test_annotated_session_dep_returns_depends_instance(self):
|
||||||
|
"""BodyDependency accepts Annotated[AsyncSession, Depends(...)] form."""
|
||||||
|
dep = BodyDependency(
|
||||||
|
Role, Role.id, session_dep=MockSessionDep, body_field="role_id"
|
||||||
|
)
|
||||||
|
assert isinstance(dep, Depends)
|
||||||
|
|
||||||
|
def test_annotated_session_dep_unwraps_callable(self):
|
||||||
|
"""BodyDependency with Annotated form uses the underlying callable, not the Annotated type."""
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
Role, Role.id, session_dep=MockSessionDep, body_field="role_id"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
sig = inspect.signature(dep.dependency)
|
||||||
|
|
||||||
|
inner_dep = sig.parameters["session"].default
|
||||||
|
assert inner_dep.dependency is mock_get_db
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_annotated_session_dep_fetches_object(self, db_session):
|
||||||
|
"""BodyDependency with Annotated session_dep correctly fetches object from database."""
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="body_annotated_role"))
|
||||||
|
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
Role, Role.id, session_dep=MockSessionDep, body_field="role_id"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
result = await dep.dependency(session=db_session, role_id=role.id)
|
||||||
|
|
||||||
|
assert result.id == role.id
|
||||||
|
assert result.name == "body_annotated_role"
|
||||||
485
tests/test_example_pagination_search.py
Normal file
485
tests/test_example_pagination_search.py
Normal file
@@ -0,0 +1,485 @@
|
|||||||
|
"""Live test for the docs/examples/pagination-search.md example.
|
||||||
|
|
||||||
|
Spins up the exact FastAPI app described in the example (sourced from
|
||||||
|
docs_src/examples/pagination_search/) and exercises it through a real HTTP
|
||||||
|
client against a real PostgreSQL database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from docs_src.examples.pagination_search.db import get_db
|
||||||
|
from docs_src.examples.pagination_search.models import Article, Base, Category
|
||||||
|
from docs_src.examples.pagination_search.routes import router
|
||||||
|
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||||
|
from fastapi_toolsets.pytest import create_db_session
|
||||||
|
|
||||||
|
from .conftest import DATABASE_URL
|
||||||
|
|
||||||
|
|
||||||
|
def build_app(session: AsyncSession) -> FastAPI:
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
async def override_get_db():
|
||||||
|
yield session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_db] = override_get_db
|
||||||
|
app.include_router(router)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
async def ex_db_session():
|
||||||
|
"""Isolated session for the example models (separate tables from conftest)."""
|
||||||
|
async with create_db_session(DATABASE_URL, Base) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client(ex_db_session: AsyncSession):
|
||||||
|
app = build_app(ex_db_session)
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=app), base_url="http://test"
|
||||||
|
) as ac:
|
||||||
|
yield ac
|
||||||
|
|
||||||
|
|
||||||
|
async def seed(session: AsyncSession):
|
||||||
|
"""Insert representative fixture data."""
|
||||||
|
python = Category(name="python")
|
||||||
|
backend = Category(name="backend")
|
||||||
|
session.add_all([python, backend])
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
now = datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc)
|
||||||
|
session.add_all(
|
||||||
|
[
|
||||||
|
Article(
|
||||||
|
title="FastAPI tips",
|
||||||
|
body="Ten useful tips for FastAPI.",
|
||||||
|
status="published",
|
||||||
|
published=True,
|
||||||
|
category_id=python.id,
|
||||||
|
created_at=now,
|
||||||
|
),
|
||||||
|
Article(
|
||||||
|
title="SQLAlchemy async",
|
||||||
|
body="How to use async SQLAlchemy.",
|
||||||
|
status="published",
|
||||||
|
published=True,
|
||||||
|
category_id=backend.id,
|
||||||
|
created_at=now + datetime.timedelta(seconds=1),
|
||||||
|
),
|
||||||
|
Article(
|
||||||
|
title="Draft notes",
|
||||||
|
body="Work in progress.",
|
||||||
|
status="draft",
|
||||||
|
published=False,
|
||||||
|
category_id=None,
|
||||||
|
created_at=now + datetime.timedelta(seconds=2),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAppSessionDep:
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_db_yields_async_session(self):
|
||||||
|
"""get_db yields a real AsyncSession when called directly."""
|
||||||
|
from docs_src.examples.pagination_search.db import get_db
|
||||||
|
|
||||||
|
gen = get_db()
|
||||||
|
session = await gen.__anext__()
|
||||||
|
assert isinstance(session, AsyncSession)
|
||||||
|
await gen.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
class TestOffsetPagination:
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_returns_all_articles(self, client: AsyncClient, ex_db_session):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["pagination"]["total_count"] == 3
|
||||||
|
assert len(body["data"]) == 3
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_pagination_page_size(self, client: AsyncClient, ex_db_session):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset?items_per_page=2&page=1")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert len(body["data"]) == 2
|
||||||
|
assert body["pagination"]["total_count"] == 3
|
||||||
|
assert body["pagination"]["has_more"] is True
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_full_text_search(self, client: AsyncClient, ex_db_session):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset?search=fastapi")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["pagination"]["total_count"] == 1
|
||||||
|
assert body["data"][0]["title"] == "FastAPI tips"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_traverses_relationship(
|
||||||
|
self, client: AsyncClient, ex_db_session
|
||||||
|
):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
# "python" matches Category.name, not Article.title or body
|
||||||
|
resp = await client.get("/articles/offset?search=python")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["pagination"]["total_count"] == 1
|
||||||
|
assert body["data"][0]["title"] == "FastAPI tips"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_facet_filter_scalar(self, client: AsyncClient, ex_db_session):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset?status=published")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["pagination"]["total_count"] == 2
|
||||||
|
assert all(a["status"] == "published" for a in body["data"])
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_facet_filter_multi_value(self, client: AsyncClient, ex_db_session):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset?status=published&status=draft")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["pagination"]["total_count"] == 3
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_filter_attributes_in_response(
|
||||||
|
self, client: AsyncClient, ex_db_session
|
||||||
|
):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
fa = body["filter_attributes"]
|
||||||
|
assert set(fa["status"]) == {"draft", "published"}
|
||||||
|
assert set(fa["category__name"]) == {"backend", "python"}
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_filter_attributes_scoped_to_filter(
|
||||||
|
self, client: AsyncClient, ex_db_session
|
||||||
|
):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset?status=published")
|
||||||
|
|
||||||
|
body = resp.json()
|
||||||
|
# draft is filtered out → should not appear in filter_attributes
|
||||||
|
assert "draft" not in body["filter_attributes"]["status"]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_and_filter_combined(self, client: AsyncClient, ex_db_session):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset?search=async&status=published")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["pagination"]["total_count"] == 1
|
||||||
|
assert body["data"][0]["title"] == "SQLAlchemy async"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCursorPagination:
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_first_page(self, client: AsyncClient, ex_db_session):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/cursor?items_per_page=2")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert len(body["data"]) == 2
|
||||||
|
assert body["pagination"]["has_more"] is True
|
||||||
|
assert body["pagination"]["next_cursor"] is not None
|
||||||
|
assert body["pagination"]["prev_cursor"] is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_second_page(self, client: AsyncClient, ex_db_session):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
first = await client.get("/articles/cursor?items_per_page=2")
|
||||||
|
next_cursor = first.json()["pagination"]["next_cursor"]
|
||||||
|
|
||||||
|
resp = await client.get(
|
||||||
|
f"/articles/cursor?items_per_page=2&cursor={next_cursor}"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert len(body["data"]) == 1
|
||||||
|
assert body["pagination"]["has_more"] is False
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_facet_filter(self, client: AsyncClient, ex_db_session):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/cursor?status=draft")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert len(body["data"]) == 1
|
||||||
|
assert body["data"][0]["status"] == "draft"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_full_text_search(self, client: AsyncClient, ex_db_session):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/cursor?search=sqlalchemy")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert len(body["data"]) == 1
|
||||||
|
assert body["data"][0]["title"] == "SQLAlchemy async"
|
||||||
|
|
||||||
|
|
||||||
|
class TestOffsetSorting:
|
||||||
|
"""Tests for order_by / order query parameters on the offset endpoint."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_default_order_uses_created_at_asc(
|
||||||
|
self, client: AsyncClient, ex_db_session
|
||||||
|
):
|
||||||
|
"""No order_by → default field (created_at) ASC."""
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
titles = [a["title"] for a in resp.json()["data"]]
|
||||||
|
assert titles == ["FastAPI tips", "SQLAlchemy async", "Draft notes"]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_order_by_title_asc(self, client: AsyncClient, ex_db_session):
|
||||||
|
"""order_by=title&order=asc returns alphabetical order."""
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset?order_by=title&order=asc")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
titles = [a["title"] for a in resp.json()["data"]]
|
||||||
|
assert titles == ["Draft notes", "FastAPI tips", "SQLAlchemy async"]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_order_by_title_desc(self, client: AsyncClient, ex_db_session):
|
||||||
|
"""order_by=title&order=desc returns reverse alphabetical order."""
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset?order_by=title&order=desc")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
titles = [a["title"] for a in resp.json()["data"]]
|
||||||
|
assert titles == ["SQLAlchemy async", "FastAPI tips", "Draft notes"]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_order_by_created_at_desc(self, client: AsyncClient, ex_db_session):
|
||||||
|
"""order_by=created_at&order=desc returns newest-first."""
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset?order_by=created_at&order=desc")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
titles = [a["title"] for a in resp.json()["data"]]
|
||||||
|
assert titles == ["Draft notes", "SQLAlchemy async", "FastAPI tips"]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_invalid_order_by_returns_422(
|
||||||
|
self, client: AsyncClient, ex_db_session
|
||||||
|
):
|
||||||
|
"""Unknown order_by field returns 422 with SORT-422 error code."""
|
||||||
|
resp = await client.get("/articles/offset?order_by=nonexistent_field")
|
||||||
|
|
||||||
|
assert resp.status_code == 422
|
||||||
|
body = resp.json()
|
||||||
|
assert body["error_code"] == "SORT-422"
|
||||||
|
assert body["status"] == "FAIL"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCursorSorting:
|
||||||
|
"""Tests for order_by / order query parameters on the cursor endpoint.
|
||||||
|
|
||||||
|
In cursor_paginate the cursor_column is always the primary sort; order_by
|
||||||
|
acts as a secondary tiebreaker. With the seeded articles (all having unique
|
||||||
|
created_at values) the overall ordering is always created_at ASC regardless
|
||||||
|
of the order_by value — only the valid/invalid field check and the response
|
||||||
|
shape are meaningful here.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_default_order_uses_created_at_asc(
|
||||||
|
self, client: AsyncClient, ex_db_session
|
||||||
|
):
|
||||||
|
"""No order_by → default field (created_at) ASC."""
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/cursor")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
titles = [a["title"] for a in resp.json()["data"]]
|
||||||
|
assert titles == ["FastAPI tips", "SQLAlchemy async", "Draft notes"]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_order_by_title_asc_accepted(
|
||||||
|
self, client: AsyncClient, ex_db_session
|
||||||
|
):
|
||||||
|
"""order_by=title is a valid field — request succeeds and returns all articles."""
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/cursor?order_by=title&order=asc")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert len(resp.json()["data"]) == 3
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_order_by_title_desc_accepted(
|
||||||
|
self, client: AsyncClient, ex_db_session
|
||||||
|
):
|
||||||
|
"""order_by=title&order=desc is valid — request succeeds and returns all articles."""
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/cursor?order_by=title&order=desc")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert len(resp.json()["data"]) == 3
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_invalid_order_by_returns_422(
|
||||||
|
self, client: AsyncClient, ex_db_session
|
||||||
|
):
|
||||||
|
"""Unknown order_by field returns 422 with SORT-422 error code."""
|
||||||
|
resp = await client.get("/articles/cursor?order_by=nonexistent_field")
|
||||||
|
|
||||||
|
assert resp.status_code == 422
|
||||||
|
body = resp.json()
|
||||||
|
assert body["error_code"] == "SORT-422"
|
||||||
|
assert body["status"] == "FAIL"
|
||||||
|
|
||||||
|
|
||||||
|
class TestPaginateUnified:
|
||||||
|
"""Tests for the unified GET /articles/ endpoint using paginate()."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_defaults_to_offset_pagination(
|
||||||
|
self, client: AsyncClient, ex_db_session
|
||||||
|
):
|
||||||
|
"""Without pagination_type, defaults to offset pagination."""
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["pagination_type"] == "offset"
|
||||||
|
assert "total_count" in body["pagination"]
|
||||||
|
assert body["pagination"]["total_count"] == 3
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_explicit_offset_pagination(self, client: AsyncClient, ex_db_session):
|
||||||
|
"""pagination_type=offset returns OffsetPagination metadata."""
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get(
|
||||||
|
"/articles/?pagination_type=offset&page=1&items_per_page=2"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["pagination_type"] == "offset"
|
||||||
|
assert body["pagination"]["total_count"] == 3
|
||||||
|
assert body["pagination"]["page"] == 1
|
||||||
|
assert body["pagination"]["has_more"] is True
|
||||||
|
assert len(body["data"]) == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_cursor_pagination_type(self, client: AsyncClient, ex_db_session):
|
||||||
|
"""pagination_type=cursor returns CursorPagination metadata."""
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/?pagination_type=cursor&items_per_page=2")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["pagination_type"] == "cursor"
|
||||||
|
assert "next_cursor" in body["pagination"]
|
||||||
|
assert "total_count" not in body["pagination"]
|
||||||
|
assert body["pagination"]["has_more"] is True
|
||||||
|
assert len(body["data"]) == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_cursor_pagination_navigate_pages(
|
||||||
|
self, client: AsyncClient, ex_db_session
|
||||||
|
):
|
||||||
|
"""Cursor from first page can be used to fetch the next page."""
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
first = await client.get("/articles/?pagination_type=cursor&items_per_page=2")
|
||||||
|
assert first.status_code == 200
|
||||||
|
first_body = first.json()
|
||||||
|
next_cursor = first_body["pagination"]["next_cursor"]
|
||||||
|
assert next_cursor is not None
|
||||||
|
|
||||||
|
second = await client.get(
|
||||||
|
f"/articles/?pagination_type=cursor&items_per_page=2&cursor={next_cursor}"
|
||||||
|
)
|
||||||
|
assert second.status_code == 200
|
||||||
|
second_body = second.json()
|
||||||
|
assert second_body["pagination_type"] == "cursor"
|
||||||
|
assert second_body["pagination"]["has_more"] is False
|
||||||
|
assert len(second_body["data"]) == 1
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_cursor_pagination_with_search(
|
||||||
|
self, client: AsyncClient, ex_db_session
|
||||||
|
):
|
||||||
|
"""paginate() with cursor type respects search parameter."""
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/?pagination_type=cursor&search=fastapi")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["pagination_type"] == "cursor"
|
||||||
|
assert len(body["data"]) == 1
|
||||||
|
assert body["data"][0]["title"] == "FastAPI tips"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_offset_pagination_with_filter(
|
||||||
|
self, client: AsyncClient, ex_db_session
|
||||||
|
):
|
||||||
|
"""paginate() with offset type respects filter_by parameter."""
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/?pagination_type=offset&status=published")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["pagination_type"] == "offset"
|
||||||
|
assert body["pagination"]["total_count"] == 2
|
||||||
@@ -2,12 +2,14 @@
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
from fastapi.exceptions import HTTPException
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from fastapi_toolsets.exceptions import (
|
from fastapi_toolsets.exceptions import (
|
||||||
ApiException,
|
ApiException,
|
||||||
ConflictError,
|
ConflictError,
|
||||||
ForbiddenError,
|
ForbiddenError,
|
||||||
|
InvalidOrderFieldError,
|
||||||
NotFoundError,
|
NotFoundError,
|
||||||
UnauthorizedError,
|
UnauthorizedError,
|
||||||
generate_error_responses,
|
generate_error_responses,
|
||||||
@@ -35,8 +37,8 @@ class TestApiException:
|
|||||||
assert error.api_error.msg == "I'm a teapot"
|
assert error.api_error.msg == "I'm a teapot"
|
||||||
assert str(error) == "I'm a teapot"
|
assert str(error) == "I'm a teapot"
|
||||||
|
|
||||||
def test_custom_detail_message(self):
|
def test_detail_overrides_msg_and_str(self):
|
||||||
"""Custom detail overrides default message."""
|
"""detail sets both str(exc) and api_error.msg; class-level msg is unchanged."""
|
||||||
|
|
||||||
class CustomError(ApiException):
|
class CustomError(ApiException):
|
||||||
api_error = ApiError(
|
api_error = ApiError(
|
||||||
@@ -46,8 +48,172 @@ class TestApiException:
|
|||||||
err_code="BAD-400",
|
err_code="BAD-400",
|
||||||
)
|
)
|
||||||
|
|
||||||
error = CustomError("Custom message")
|
error = CustomError("Widget not found")
|
||||||
assert str(error) == "Custom message"
|
assert str(error) == "Widget not found"
|
||||||
|
assert error.api_error.msg == "Widget not found"
|
||||||
|
assert CustomError.api_error.msg == "Bad Request" # class unchanged
|
||||||
|
|
||||||
|
def test_desc_override(self):
|
||||||
|
"""desc kwarg overrides api_error.desc on the instance only."""
|
||||||
|
|
||||||
|
class MyError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400, msg="Error", desc="Default.", err_code="ERR-400"
|
||||||
|
)
|
||||||
|
|
||||||
|
err = MyError(desc="Custom desc.")
|
||||||
|
assert err.api_error.desc == "Custom desc."
|
||||||
|
assert MyError.api_error.desc == "Default." # class unchanged
|
||||||
|
|
||||||
|
def test_data_override(self):
|
||||||
|
"""data kwarg sets api_error.data on the instance only."""
|
||||||
|
|
||||||
|
class MyError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400, msg="Error", desc="Default.", err_code="ERR-400"
|
||||||
|
)
|
||||||
|
|
||||||
|
err = MyError(data={"key": "value"})
|
||||||
|
assert err.api_error.data == {"key": "value"}
|
||||||
|
assert MyError.api_error.data is None # class unchanged
|
||||||
|
|
||||||
|
def test_desc_and_data_override(self):
|
||||||
|
"""detail, desc and data can all be overridden together."""
|
||||||
|
|
||||||
|
class MyError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400, msg="Error", desc="Default.", err_code="ERR-400"
|
||||||
|
)
|
||||||
|
|
||||||
|
err = MyError("custom msg", desc="New desc.", data={"x": 1})
|
||||||
|
assert str(err) == "custom msg"
|
||||||
|
assert err.api_error.msg == "custom msg" # detail also updates msg
|
||||||
|
assert err.api_error.desc == "New desc."
|
||||||
|
assert err.api_error.data == {"x": 1}
|
||||||
|
assert err.api_error.code == 400 # other fields unchanged
|
||||||
|
|
||||||
|
def test_class_api_error_not_mutated_after_instance_override(self):
|
||||||
|
"""Raising with desc/data does not mutate the class-level api_error."""
|
||||||
|
|
||||||
|
class MyError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400, msg="Error", desc="Default.", err_code="ERR-400"
|
||||||
|
)
|
||||||
|
|
||||||
|
MyError(desc="Changed", data={"x": 1})
|
||||||
|
assert MyError.api_error.desc == "Default."
|
||||||
|
assert MyError.api_error.data is None
|
||||||
|
|
||||||
|
def test_subclass_uses_super_with_desc_and_data(self):
|
||||||
|
"""Subclasses can delegate detail/desc/data to super().__init__()."""
|
||||||
|
|
||||||
|
class BuildValidationError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=422,
|
||||||
|
msg="Build Validation Error",
|
||||||
|
desc="The build configuration is invalid.",
|
||||||
|
err_code="BUILD-422",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, *errors: str) -> None:
|
||||||
|
super().__init__(
|
||||||
|
f"{len(errors)} validation error(s)",
|
||||||
|
desc=", ".join(errors),
|
||||||
|
data={"errors": [{"message": e} for e in errors]},
|
||||||
|
)
|
||||||
|
|
||||||
|
err = BuildValidationError("Field A is required", "Field B is invalid")
|
||||||
|
assert str(err) == "2 validation error(s)"
|
||||||
|
assert err.api_error.msg == "2 validation error(s)" # detail set msg
|
||||||
|
assert err.api_error.desc == "Field A is required, Field B is invalid"
|
||||||
|
assert err.api_error.data == {
|
||||||
|
"errors": [
|
||||||
|
{"message": "Field A is required"},
|
||||||
|
{"message": "Field B is invalid"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
assert err.api_error.code == 422 # other fields unchanged
|
||||||
|
|
||||||
|
def test_detail_desc_data_in_http_response(self):
|
||||||
|
"""detail/desc/data overrides all appear correctly in the FastAPI HTTP response."""
|
||||||
|
|
||||||
|
class DynamicError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400, msg="Error", desc="Default.", err_code="ERR-400"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, message: str) -> None:
|
||||||
|
super().__init__(
|
||||||
|
message,
|
||||||
|
desc=f"Detail: {message}",
|
||||||
|
data={"reason": message},
|
||||||
|
)
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
@app.get("/error")
|
||||||
|
async def raise_error():
|
||||||
|
raise DynamicError("something went wrong")
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/error")
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
body = response.json()
|
||||||
|
assert body["message"] == "something went wrong"
|
||||||
|
assert body["description"] == "Detail: something went wrong"
|
||||||
|
assert body["data"] == {"reason": "something went wrong"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestApiExceptionGuard:
|
||||||
|
"""Tests for the __init_subclass__ api_error guard."""
|
||||||
|
|
||||||
|
def test_missing_api_error_raises_type_error(self):
|
||||||
|
"""Defining a subclass without api_error raises TypeError at class creation time."""
|
||||||
|
with pytest.raises(
|
||||||
|
TypeError, match="must define an 'api_error' class attribute"
|
||||||
|
):
|
||||||
|
|
||||||
|
class BrokenError(ApiException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_abstract_subclass_skips_guard(self):
|
||||||
|
"""abstract=True allows intermediate base classes without api_error."""
|
||||||
|
|
||||||
|
class BaseGroupError(ApiException, abstract=True):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Concrete child must still define it
|
||||||
|
class ConcreteError(BaseGroupError):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400, msg="Error", desc="Desc.", err_code="ERR-400"
|
||||||
|
)
|
||||||
|
|
||||||
|
err = ConcreteError()
|
||||||
|
assert err.api_error.code == 400
|
||||||
|
|
||||||
|
def test_abstract_child_still_requires_api_error_on_concrete(self):
|
||||||
|
"""Concrete subclass of an abstract class must define api_error."""
|
||||||
|
|
||||||
|
class Base(ApiException, abstract=True):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
TypeError, match="must define an 'api_error' class attribute"
|
||||||
|
):
|
||||||
|
|
||||||
|
class Concrete(Base):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_inherited_api_error_satisfies_guard(self):
|
||||||
|
"""Subclass that inherits api_error from a parent does not need its own."""
|
||||||
|
|
||||||
|
class ConcreteError(NotFoundError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
err = ConcreteError()
|
||||||
|
assert err.api_error.code == 404
|
||||||
|
|
||||||
|
|
||||||
class TestBuiltInExceptions:
|
class TestBuiltInExceptions:
|
||||||
@@ -89,7 +255,7 @@ class TestGenerateErrorResponses:
|
|||||||
assert responses[404]["description"] == "Not Found"
|
assert responses[404]["description"] == "Not Found"
|
||||||
|
|
||||||
def test_generates_multiple_responses(self):
|
def test_generates_multiple_responses(self):
|
||||||
"""Generates responses for multiple exceptions."""
|
"""Generates responses for multiple exceptions with distinct status codes."""
|
||||||
responses = generate_error_responses(
|
responses = generate_error_responses(
|
||||||
UnauthorizedError,
|
UnauthorizedError,
|
||||||
ForbiddenError,
|
ForbiddenError,
|
||||||
@@ -100,14 +266,81 @@ class TestGenerateErrorResponses:
|
|||||||
assert 403 in responses
|
assert 403 in responses
|
||||||
assert 404 in responses
|
assert 404 in responses
|
||||||
|
|
||||||
def test_response_has_example(self):
|
def test_response_has_named_example(self):
|
||||||
"""Generated response includes example."""
|
"""Generated response uses named examples keyed by err_code."""
|
||||||
responses = generate_error_responses(NotFoundError)
|
responses = generate_error_responses(NotFoundError)
|
||||||
example = responses[404]["content"]["application/json"]["example"]
|
examples = responses[404]["content"]["application/json"]["examples"]
|
||||||
|
|
||||||
assert example["status"] == "FAIL"
|
assert "RES-404" in examples
|
||||||
assert example["error_code"] == "RES-404"
|
value = examples["RES-404"]["value"]
|
||||||
assert example["message"] == "Not Found"
|
assert value["status"] == "FAIL"
|
||||||
|
assert value["error_code"] == "RES-404"
|
||||||
|
assert value["message"] == "Not Found"
|
||||||
|
assert value["data"] is None
|
||||||
|
|
||||||
|
def test_response_example_has_summary(self):
|
||||||
|
"""Each named example carries a summary equal to api_error.msg."""
|
||||||
|
responses = generate_error_responses(NotFoundError)
|
||||||
|
example = responses[404]["content"]["application/json"]["examples"]["RES-404"]
|
||||||
|
|
||||||
|
assert example["summary"] == "Not Found"
|
||||||
|
|
||||||
|
def test_response_example_with_data(self):
|
||||||
|
"""Generated response includes data when set on ApiError."""
|
||||||
|
|
||||||
|
class ErrorWithData(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400,
|
||||||
|
msg="Bad Request",
|
||||||
|
desc="Invalid input.",
|
||||||
|
err_code="BAD-400",
|
||||||
|
data={"details": "some context"},
|
||||||
|
)
|
||||||
|
|
||||||
|
responses = generate_error_responses(ErrorWithData)
|
||||||
|
value = responses[400]["content"]["application/json"]["examples"]["BAD-400"][
|
||||||
|
"value"
|
||||||
|
]
|
||||||
|
|
||||||
|
assert value["data"] == {"details": "some context"}
|
||||||
|
|
||||||
|
def test_two_errors_same_code_both_present(self):
|
||||||
|
"""Two exceptions with the same HTTP code produce two named examples."""
|
||||||
|
|
||||||
|
class BadRequestA(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400, msg="Bad A", desc="Reason A.", err_code="ERR-A"
|
||||||
|
)
|
||||||
|
|
||||||
|
class BadRequestB(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400, msg="Bad B", desc="Reason B.", err_code="ERR-B"
|
||||||
|
)
|
||||||
|
|
||||||
|
responses = generate_error_responses(BadRequestA, BadRequestB)
|
||||||
|
|
||||||
|
assert 400 in responses
|
||||||
|
examples = responses[400]["content"]["application/json"]["examples"]
|
||||||
|
assert "ERR-A" in examples
|
||||||
|
assert "ERR-B" in examples
|
||||||
|
assert examples["ERR-A"]["value"]["message"] == "Bad A"
|
||||||
|
assert examples["ERR-B"]["value"]["message"] == "Bad B"
|
||||||
|
|
||||||
|
def test_two_errors_same_code_single_top_level_entry(self):
|
||||||
|
"""Two exceptions with the same HTTP code produce exactly one top-level entry."""
|
||||||
|
|
||||||
|
class BadRequestA(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400, msg="Bad A", desc="Reason A.", err_code="ERR-A"
|
||||||
|
)
|
||||||
|
|
||||||
|
class BadRequestB(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400, msg="Bad B", desc="Reason B.", err_code="ERR-B"
|
||||||
|
)
|
||||||
|
|
||||||
|
responses = generate_error_responses(BadRequestA, BadRequestB)
|
||||||
|
assert len([k for k in responses if k == 400]) == 1
|
||||||
|
|
||||||
|
|
||||||
class TestInitExceptionsHandlers:
|
class TestInitExceptionsHandlers:
|
||||||
@@ -137,6 +370,59 @@ class TestInitExceptionsHandlers:
|
|||||||
assert data["error_code"] == "RES-404"
|
assert data["error_code"] == "RES-404"
|
||||||
assert data["message"] == "Not Found"
|
assert data["message"] == "Not Found"
|
||||||
|
|
||||||
|
def test_handles_api_exception_without_data(self):
|
||||||
|
"""ApiException without data returns null data field."""
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
@app.get("/error")
|
||||||
|
async def raise_error():
|
||||||
|
raise NotFoundError()
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/error")
|
||||||
|
|
||||||
|
assert response.status_code == 404
|
||||||
|
assert response.json()["data"] is None
|
||||||
|
|
||||||
|
def test_handles_api_exception_with_data(self):
|
||||||
|
"""ApiException with data returns the data payload."""
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
class CustomValidationError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=422,
|
||||||
|
msg="Validation Error",
|
||||||
|
desc="1 validation error(s) detected",
|
||||||
|
err_code="CUSTOM-422",
|
||||||
|
data={
|
||||||
|
"errors": [
|
||||||
|
{
|
||||||
|
"field": "email",
|
||||||
|
"message": "invalid format",
|
||||||
|
"type": "value_error",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/error")
|
||||||
|
async def raise_error():
|
||||||
|
raise CustomValidationError()
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/error")
|
||||||
|
|
||||||
|
assert response.status_code == 422
|
||||||
|
data = response.json()
|
||||||
|
assert data["data"] == {
|
||||||
|
"errors": [
|
||||||
|
{"field": "email", "message": "invalid format", "type": "value_error"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
assert data["error_code"] == "CUSTOM-422"
|
||||||
|
|
||||||
def test_handles_validation_error(self):
|
def test_handles_validation_error(self):
|
||||||
"""Handles validation errors with structured response."""
|
"""Handles validation errors with structured response."""
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -178,13 +464,68 @@ class TestInitExceptionsHandlers:
|
|||||||
assert data["status"] == "FAIL"
|
assert data["status"] == "FAIL"
|
||||||
assert data["error_code"] == "SERVER-500"
|
assert data["error_code"] == "SERVER-500"
|
||||||
|
|
||||||
def test_custom_openapi_schema(self):
|
def test_handles_http_exception(self):
|
||||||
"""Customizes OpenAPI schema for 422 responses."""
|
"""Handles starlette HTTPException with consistent ErrorResponse envelope."""
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
init_exceptions_handlers(app)
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
@app.get("/protected")
|
||||||
|
async def protected():
|
||||||
|
raise HTTPException(status_code=403, detail="Forbidden")
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/protected")
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "FAIL"
|
||||||
|
assert data["error_code"] == "HTTP-403"
|
||||||
|
assert data["message"] == "Forbidden"
|
||||||
|
|
||||||
|
def test_handles_http_exception_404_from_route(self):
|
||||||
|
"""HTTPException(404) raised inside a route uses the consistent ErrorResponse envelope."""
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
@app.get("/items/{item_id}")
|
||||||
|
async def get_item(item_id: int):
|
||||||
|
raise HTTPException(status_code=404, detail="Item not found")
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/items/99")
|
||||||
|
|
||||||
|
assert response.status_code == 404
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "FAIL"
|
||||||
|
assert data["error_code"] == "HTTP-404"
|
||||||
|
assert data["message"] == "Item not found"
|
||||||
|
|
||||||
|
def test_handles_http_exception_forwards_headers(self):
|
||||||
|
"""HTTPException with WWW-Authenticate header forwards it in the response."""
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
@app.get("/secure")
|
||||||
|
async def secure():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail="Not authenticated",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/secure")
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.headers.get("www-authenticate") == "Bearer"
|
||||||
|
|
||||||
|
def test_custom_openapi_schema(self):
|
||||||
|
"""Customises OpenAPI schema for 422 responses using named examples."""
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
class Item(BaseModel):
|
class Item(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
@@ -197,8 +538,128 @@ class TestInitExceptionsHandlers:
|
|||||||
post_op = openapi["paths"]["/items"]["post"]
|
post_op = openapi["paths"]["/items"]["post"]
|
||||||
assert "422" in post_op["responses"]
|
assert "422" in post_op["responses"]
|
||||||
resp_422 = post_op["responses"]["422"]
|
resp_422 = post_op["responses"]["422"]
|
||||||
example = resp_422["content"]["application/json"]["example"]
|
examples = resp_422["content"]["application/json"]["examples"]
|
||||||
assert example["error_code"] == "VAL-422"
|
assert "VAL-422" in examples
|
||||||
|
assert examples["VAL-422"]["value"]["error_code"] == "VAL-422"
|
||||||
|
|
||||||
|
def test_custom_openapi_preserves_app_metadata(self):
|
||||||
|
"""_patched_openapi preserves custom FastAPI app-level metadata."""
|
||||||
|
app = FastAPI(
|
||||||
|
title="My API",
|
||||||
|
version="2.0.0",
|
||||||
|
description="Custom description",
|
||||||
|
)
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
schema = app.openapi()
|
||||||
|
assert schema["info"]["title"] == "My API"
|
||||||
|
assert schema["info"]["version"] == "2.0.0"
|
||||||
|
|
||||||
|
def test_handles_response_validation_error(self):
|
||||||
|
"""Handles ResponseValidationError with a structured 422 response."""
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class CountResponse(BaseModel):
|
||||||
|
count: int
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
@app.get("/broken", response_model=CountResponse)
|
||||||
|
async def broken():
|
||||||
|
return {"count": "not-a-number"} # triggers ResponseValidationError
|
||||||
|
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
|
response = client.get("/broken")
|
||||||
|
|
||||||
|
assert response.status_code == 422
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "FAIL"
|
||||||
|
assert data["error_code"] == "VAL-422"
|
||||||
|
assert "errors" in data["data"]
|
||||||
|
|
||||||
|
def test_handles_validation_error_with_non_standard_loc(self):
|
||||||
|
"""Validation error with empty loc tuple maps the field to 'root'."""
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
@app.get("/root-error")
|
||||||
|
async def root_error():
|
||||||
|
raise RequestValidationError(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "custom",
|
||||||
|
"loc": (),
|
||||||
|
"msg": "root level error",
|
||||||
|
"input": None,
|
||||||
|
"url": "",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/root-error")
|
||||||
|
|
||||||
|
assert response.status_code == 422
|
||||||
|
data = response.json()
|
||||||
|
assert data["data"]["errors"][0]["field"] == "root"
|
||||||
|
|
||||||
|
def test_openapi_schema_cached_after_first_call(self):
|
||||||
|
"""app.openapi() returns the cached schema on subsequent calls."""
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
class Item(BaseModel):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
@app.post("/items")
|
||||||
|
async def create_item(item: Item):
|
||||||
|
return item
|
||||||
|
|
||||||
|
schema_first = app.openapi()
|
||||||
|
schema_second = app.openapi()
|
||||||
|
assert schema_first is schema_second
|
||||||
|
|
||||||
|
def test_openapi_skips_operations_without_422(self):
|
||||||
|
"""_patched_openapi leaves operations that have no 422 response unchanged."""
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
@app.get("/ping")
|
||||||
|
async def ping():
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
schema = app.openapi()
|
||||||
|
get_op = schema["paths"]["/ping"]["get"]
|
||||||
|
assert "422" not in get_op["responses"]
|
||||||
|
assert "200" in get_op["responses"]
|
||||||
|
|
||||||
|
def test_openapi_skips_non_dict_path_item_values(self):
|
||||||
|
"""_patched_openapi ignores non-dict values in path items (e.g. path-level parameters)."""
|
||||||
|
from fastapi_toolsets.exceptions.handler import _patched_openapi
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
def fake_openapi() -> dict:
|
||||||
|
return {
|
||||||
|
"paths": {
|
||||||
|
"/items": {
|
||||||
|
"parameters": [
|
||||||
|
{"name": "q", "in": "query"}
|
||||||
|
], # list, not a dict
|
||||||
|
"get": {"responses": {"200": {"description": "OK"}}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
schema = _patched_openapi(app, fake_openapi)
|
||||||
|
# The list value was skipped without error; the GET operation is intact
|
||||||
|
assert schema["paths"]["/items"]["parameters"] == [{"name": "q", "in": "query"}]
|
||||||
|
assert "422" not in schema["paths"]["/items"]["get"]["responses"]
|
||||||
|
|
||||||
|
|
||||||
class TestExceptionIntegration:
|
class TestExceptionIntegration:
|
||||||
@@ -263,3 +724,43 @@ class TestExceptionIntegration:
|
|||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {"id": 1}
|
assert response.json() == {"id": 1}
|
||||||
|
|
||||||
|
|
||||||
|
class TestInvalidOrderFieldError:
|
||||||
|
"""Tests for InvalidOrderFieldError exception."""
|
||||||
|
|
||||||
|
def test_api_error_attributes(self):
|
||||||
|
"""InvalidOrderFieldError has correct api_error metadata."""
|
||||||
|
assert InvalidOrderFieldError.api_error.code == 422
|
||||||
|
assert InvalidOrderFieldError.api_error.err_code == "SORT-422"
|
||||||
|
assert InvalidOrderFieldError.api_error.msg == "Invalid Order Field"
|
||||||
|
|
||||||
|
def test_stores_field_and_valid_fields(self):
|
||||||
|
"""InvalidOrderFieldError stores field and valid_fields on the instance."""
|
||||||
|
error = InvalidOrderFieldError("unknown", ["name", "created_at"])
|
||||||
|
assert error.field == "unknown"
|
||||||
|
assert error.valid_fields == ["name", "created_at"]
|
||||||
|
|
||||||
|
def test_description_contains_field_and_valid_fields(self):
|
||||||
|
"""api_error.desc mentions the bad field and valid options."""
|
||||||
|
error = InvalidOrderFieldError("bad_field", ["name", "email"])
|
||||||
|
assert "bad_field" in error.api_error.desc
|
||||||
|
assert "name" in error.api_error.desc
|
||||||
|
assert "email" in error.api_error.desc
|
||||||
|
|
||||||
|
def test_handled_as_422_by_exception_handler(self):
|
||||||
|
"""init_exceptions_handlers turns InvalidOrderFieldError into a 422 response."""
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
@app.get("/items")
|
||||||
|
async def list_items():
|
||||||
|
raise InvalidOrderFieldError("bad", ["name"])
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/items")
|
||||||
|
|
||||||
|
assert response.status_code == 422
|
||||||
|
data = response.json()
|
||||||
|
assert data["error_code"] == "SORT-422"
|
||||||
|
assert data["status"] == "FAIL"
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,57 +0,0 @@
|
|||||||
"""Tests for fastapi_toolsets.fixtures.utils."""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from fastapi_toolsets.fixtures import FixtureRegistry
|
|
||||||
from fastapi_toolsets.fixtures.utils import get_obj_by_attr
|
|
||||||
|
|
||||||
from .conftest import Role, User
|
|
||||||
|
|
||||||
registry = FixtureRegistry()
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
|
||||||
def roles() -> list[Role]:
|
|
||||||
return [
|
|
||||||
Role(id=1, name="admin"),
|
|
||||||
Role(id=2, name="user"),
|
|
||||||
Role(id=3, name="moderator"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register(depends_on=["roles"])
|
|
||||||
def users() -> list[User]:
|
|
||||||
return [
|
|
||||||
User(id=1, username="alice", email="alice@example.com", role_id=1),
|
|
||||||
User(id=2, username="bob", email="bob@example.com", role_id=1),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetObjByAttr:
|
|
||||||
"""Tests for get_obj_by_attr."""
|
|
||||||
|
|
||||||
def test_get_by_id(self):
|
|
||||||
"""Get an object by its id attribute."""
|
|
||||||
role = get_obj_by_attr(roles, "id", 1)
|
|
||||||
assert role.name == "admin"
|
|
||||||
|
|
||||||
def test_get_user_by_username(self):
|
|
||||||
"""Get a user by username."""
|
|
||||||
user = get_obj_by_attr(users, "username", "bob")
|
|
||||||
assert user.id == 2
|
|
||||||
assert user.email == "bob@example.com"
|
|
||||||
|
|
||||||
def test_returns_first_match(self):
|
|
||||||
"""Returns the first matching object when multiple could match."""
|
|
||||||
user = get_obj_by_attr(users, "role_id", 1)
|
|
||||||
assert user.username == "alice"
|
|
||||||
|
|
||||||
def test_no_match_raises_stop_iteration(self):
|
|
||||||
"""Raises StopIteration when no object matches."""
|
|
||||||
with pytest.raises(StopIteration):
|
|
||||||
get_obj_by_attr(roles, "name", "nonexistent")
|
|
||||||
|
|
||||||
def test_no_match_on_wrong_value_type(self):
|
|
||||||
"""Raises StopIteration when value type doesn't match."""
|
|
||||||
with pytest.raises(StopIteration):
|
|
||||||
get_obj_by_attr(roles, "id", "1")
|
|
||||||
210
tests/test_imports.py
Normal file
210
tests/test_imports.py
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
"""Tests for optional dependency import guards."""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from fastapi_toolsets._imports import require_extra
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequireExtra:
|
||||||
|
"""Tests for the require_extra helper."""
|
||||||
|
|
||||||
|
def test_raises_import_error(self):
|
||||||
|
"""require_extra raises ImportError."""
|
||||||
|
with pytest.raises(ImportError):
|
||||||
|
require_extra(package="some_pkg", extra="some_extra")
|
||||||
|
|
||||||
|
def test_error_message_contains_package_name(self):
|
||||||
|
"""Error message mentions the missing package."""
|
||||||
|
with pytest.raises(ImportError, match="'prometheus_client'"):
|
||||||
|
require_extra(package="prometheus_client", extra="metrics")
|
||||||
|
|
||||||
|
def test_error_message_contains_install_instruction(self):
|
||||||
|
"""Error message contains the pip install command."""
|
||||||
|
with pytest.raises(
|
||||||
|
ImportError, match=r"pip install fastapi-toolsets\[metrics\]"
|
||||||
|
):
|
||||||
|
require_extra(package="prometheus_client", extra="metrics")
|
||||||
|
|
||||||
|
|
||||||
|
def _reload_without_package(module_path: str, blocked_packages: list[str]):
|
||||||
|
"""Reload a module while blocking specific package imports.
|
||||||
|
|
||||||
|
Removes the target module and its parents from sys.modules so they
|
||||||
|
get re-imported, and patches builtins.__import__ to raise ImportError
|
||||||
|
for *blocked_packages*.
|
||||||
|
"""
|
||||||
|
# Remove cached modules so they get re-imported
|
||||||
|
to_remove = [
|
||||||
|
key
|
||||||
|
for key in sys.modules
|
||||||
|
if key == module_path or key.startswith(module_path + ".")
|
||||||
|
]
|
||||||
|
saved = {}
|
||||||
|
for key in to_remove:
|
||||||
|
saved[key] = sys.modules.pop(key)
|
||||||
|
|
||||||
|
# Also remove parent package to force re-execution of __init__.py
|
||||||
|
parts = module_path.rsplit(".", 1)
|
||||||
|
if len(parts) == 2:
|
||||||
|
parent = parts[0]
|
||||||
|
parent_keys = [
|
||||||
|
key for key in sys.modules if key == parent or key.startswith(parent + ".")
|
||||||
|
]
|
||||||
|
for key in parent_keys:
|
||||||
|
if key not in saved:
|
||||||
|
saved[key] = sys.modules.pop(key)
|
||||||
|
|
||||||
|
original_import = (
|
||||||
|
__builtins__.__import__ if hasattr(__builtins__, "__import__") else __import__
|
||||||
|
)
|
||||||
|
|
||||||
|
def blocking_import(name, *args, **kwargs):
|
||||||
|
for blocked in blocked_packages:
|
||||||
|
if name == blocked or name.startswith(blocked + "."):
|
||||||
|
raise ImportError(f"Mocked: No module named '{name}'")
|
||||||
|
return original_import(name, *args, **kwargs)
|
||||||
|
|
||||||
|
return saved, blocking_import
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricsImportGuard:
|
||||||
|
"""Tests for metrics module import guard when prometheus_client is missing."""
|
||||||
|
|
||||||
|
def test_registry_imports_without_prometheus(self):
|
||||||
|
"""Metric and MetricsRegistry are importable without prometheus_client."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.metrics", ["prometheus_client"]
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
mod = importlib.import_module("fastapi_toolsets.metrics")
|
||||||
|
# Registry types should be available (they're stdlib-only)
|
||||||
|
assert hasattr(mod, "Metric")
|
||||||
|
assert hasattr(mod, "MetricsRegistry")
|
||||||
|
finally:
|
||||||
|
# Restore original modules
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.metrics"):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_init_metrics_stub_raises_without_prometheus(self):
|
||||||
|
"""init_metrics raises ImportError when prometheus_client is missing."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.metrics", ["prometheus_client"]
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
mod = importlib.import_module("fastapi_toolsets.metrics")
|
||||||
|
with pytest.raises(ImportError, match="prometheus_client"):
|
||||||
|
mod.init_metrics(None, None) # type: ignore[arg-type] # ty:ignore[invalid-argument-type]
|
||||||
|
finally:
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.metrics"):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_init_metrics_works_with_prometheus(self):
|
||||||
|
"""init_metrics is the real function when prometheus_client is available."""
|
||||||
|
from fastapi_toolsets.metrics import init_metrics
|
||||||
|
|
||||||
|
# Should be the real function, not a stub
|
||||||
|
assert init_metrics.__module__ == "fastapi_toolsets.metrics.handler"
|
||||||
|
|
||||||
|
|
||||||
|
class TestPytestImportGuard:
|
||||||
|
"""Tests for pytest module import guard when dependencies are missing."""
|
||||||
|
|
||||||
|
def test_import_raises_without_pytest_package(self):
|
||||||
|
"""Importing fastapi_toolsets.pytest raises when pytest is missing."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.pytest", ["pytest"]
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
with pytest.raises(ImportError, match="pytest"):
|
||||||
|
importlib.import_module("fastapi_toolsets.pytest")
|
||||||
|
finally:
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.pytest"):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_import_raises_without_httpx(self):
|
||||||
|
"""Importing fastapi_toolsets.pytest raises when httpx is missing."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.pytest", ["httpx"]
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
with pytest.raises(ImportError, match="httpx"):
|
||||||
|
importlib.import_module("fastapi_toolsets.pytest")
|
||||||
|
finally:
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.pytest"):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_all_exports_available_with_deps(self):
|
||||||
|
"""All expected exports are available when deps are installed."""
|
||||||
|
from fastapi_toolsets.pytest import (
|
||||||
|
cleanup_tables,
|
||||||
|
create_async_client,
|
||||||
|
create_db_session,
|
||||||
|
create_worker_database,
|
||||||
|
register_fixtures,
|
||||||
|
worker_database_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert callable(register_fixtures)
|
||||||
|
assert callable(create_async_client)
|
||||||
|
assert callable(create_db_session)
|
||||||
|
assert callable(create_worker_database)
|
||||||
|
assert callable(worker_database_url)
|
||||||
|
assert callable(cleanup_tables)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCliImportGuard:
|
||||||
|
"""Tests for CLI module import guard when typer is missing."""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"expected_match",
|
||||||
|
[
|
||||||
|
"typer",
|
||||||
|
r"pip install fastapi-toolsets\[cli\]",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_import_raises_without_typer(self, expected_match):
|
||||||
|
"""Importing cli.app raises when typer is missing, with an informative error message."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.cli.app", ["typer"]
|
||||||
|
)
|
||||||
|
# Also remove cli.config since it imports typer too
|
||||||
|
config_keys = [
|
||||||
|
k for k in sys.modules if k.startswith("fastapi_toolsets.cli.config")
|
||||||
|
]
|
||||||
|
for key in config_keys:
|
||||||
|
if key not in saved:
|
||||||
|
saved[key] = sys.modules.pop(key)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
with pytest.raises(ImportError, match=expected_match):
|
||||||
|
importlib.import_module("fastapi_toolsets.cli.app")
|
||||||
|
finally:
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.cli.app") or key.startswith(
|
||||||
|
"fastapi_toolsets.cli.config"
|
||||||
|
):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_async_command_imports_without_typer(self):
|
||||||
|
"""async_command is importable without typer (stdlib only)."""
|
||||||
|
from fastapi_toolsets.cli import async_command
|
||||||
|
|
||||||
|
assert callable(async_command)
|
||||||
118
tests/test_logger.py
Normal file
118
tests/test_logger.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from fastapi_toolsets.logger import (
|
||||||
|
DEFAULT_FORMAT,
|
||||||
|
UVICORN_LOGGERS,
|
||||||
|
configure_logging,
|
||||||
|
get_logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _reset_loggers():
|
||||||
|
"""Reset the root and uvicorn loggers after each test."""
|
||||||
|
yield
|
||||||
|
root = logging.getLogger()
|
||||||
|
root.handlers.clear()
|
||||||
|
root.setLevel(logging.WARNING)
|
||||||
|
for name in UVICORN_LOGGERS:
|
||||||
|
uv = logging.getLogger(name)
|
||||||
|
uv.handlers.clear()
|
||||||
|
uv.setLevel(logging.NOTSET)
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfigureLogging:
|
||||||
|
def test_sets_up_handler_and_format(self):
|
||||||
|
logger = configure_logging()
|
||||||
|
|
||||||
|
assert len(logger.handlers) == 1
|
||||||
|
handler = logger.handlers[0]
|
||||||
|
assert isinstance(handler, logging.StreamHandler)
|
||||||
|
assert handler.stream is sys.stdout
|
||||||
|
assert handler.formatter is not None
|
||||||
|
assert handler.formatter._fmt == DEFAULT_FORMAT
|
||||||
|
|
||||||
|
def test_default_level_is_info(self):
|
||||||
|
logger = configure_logging()
|
||||||
|
|
||||||
|
assert logger.level == logging.INFO
|
||||||
|
|
||||||
|
def test_custom_level_string(self):
|
||||||
|
logger = configure_logging(level="DEBUG")
|
||||||
|
|
||||||
|
assert logger.level == logging.DEBUG
|
||||||
|
|
||||||
|
def test_custom_level_int(self):
|
||||||
|
logger = configure_logging(level=logging.WARNING)
|
||||||
|
|
||||||
|
assert logger.level == logging.WARNING
|
||||||
|
|
||||||
|
def test_custom_format(self):
|
||||||
|
custom_fmt = "%(levelname)s: %(message)s"
|
||||||
|
logger = configure_logging(fmt=custom_fmt)
|
||||||
|
|
||||||
|
handler = logger.handlers[0]
|
||||||
|
assert handler.formatter is not None
|
||||||
|
assert handler.formatter._fmt == custom_fmt
|
||||||
|
|
||||||
|
def test_named_logger(self):
|
||||||
|
logger = configure_logging(logger_name="myapp")
|
||||||
|
|
||||||
|
assert logger.name == "myapp"
|
||||||
|
assert len(logger.handlers) == 1
|
||||||
|
|
||||||
|
def test_default_configures_root_logger(self):
|
||||||
|
logger = configure_logging()
|
||||||
|
|
||||||
|
assert logger is logging.getLogger()
|
||||||
|
|
||||||
|
def test_idempotent_no_duplicate_handlers(self):
|
||||||
|
configure_logging()
|
||||||
|
configure_logging()
|
||||||
|
logger = configure_logging()
|
||||||
|
|
||||||
|
assert len(logger.handlers) == 1
|
||||||
|
|
||||||
|
def test_configures_uvicorn_loggers(self):
|
||||||
|
configure_logging(level="DEBUG")
|
||||||
|
|
||||||
|
for name in UVICORN_LOGGERS:
|
||||||
|
uv_logger = logging.getLogger(name)
|
||||||
|
assert len(uv_logger.handlers) == 1
|
||||||
|
assert uv_logger.level == logging.DEBUG
|
||||||
|
handler = uv_logger.handlers[0]
|
||||||
|
assert handler.formatter is not None
|
||||||
|
assert handler.formatter._fmt == DEFAULT_FORMAT
|
||||||
|
|
||||||
|
def test_returns_configured_logger(self):
|
||||||
|
logger = configure_logging(logger_name="test.return")
|
||||||
|
|
||||||
|
assert isinstance(logger, logging.Logger)
|
||||||
|
assert logger.name == "test.return"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetLogger:
|
||||||
|
def test_returns_named_logger(self):
|
||||||
|
logger = get_logger("myapp.services")
|
||||||
|
|
||||||
|
assert isinstance(logger, logging.Logger)
|
||||||
|
assert logger.name == "myapp.services"
|
||||||
|
|
||||||
|
def test_returns_root_logger_when_none(self):
|
||||||
|
logger = get_logger(None)
|
||||||
|
|
||||||
|
assert logger is logging.getLogger()
|
||||||
|
|
||||||
|
def test_defaults_to_caller_module_name(self):
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
assert logger.name == __name__
|
||||||
|
|
||||||
|
def test_same_name_returns_same_logger(self):
|
||||||
|
a = get_logger("myapp")
|
||||||
|
b = get_logger("myapp")
|
||||||
|
|
||||||
|
assert a is b
|
||||||
549
tests/test_metrics.py
Normal file
549
tests/test_metrics.py
Normal file
@@ -0,0 +1,549 @@
|
|||||||
|
"""Tests for fastapi_toolsets.metrics module."""
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from prometheus_client import REGISTRY, CollectorRegistry, Counter, Gauge
|
||||||
|
|
||||||
|
from fastapi_toolsets.metrics import Metric, MetricsRegistry, init_metrics
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _clean_prometheus_registry():
|
||||||
|
"""Unregister test collectors from the global registry after each test."""
|
||||||
|
yield
|
||||||
|
collectors = list(REGISTRY._names_to_collectors.values())
|
||||||
|
for collector in collectors:
|
||||||
|
try:
|
||||||
|
REGISTRY.unregister(collector)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetric:
|
||||||
|
"""Tests for Metric dataclass."""
|
||||||
|
|
||||||
|
def test_default_collect_is_false(self):
|
||||||
|
"""Default collect is False (provider mode)."""
|
||||||
|
definition = Metric(name="test", func=lambda: None)
|
||||||
|
assert definition.collect is False
|
||||||
|
|
||||||
|
def test_collect_true(self):
|
||||||
|
"""Collect can be set to True (collector mode)."""
|
||||||
|
definition = Metric(name="test", func=lambda: None, collect=True)
|
||||||
|
assert definition.collect is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricsRegistry:
|
||||||
|
"""Tests for MetricsRegistry class."""
|
||||||
|
|
||||||
|
def test_register_with_decorator(self):
|
||||||
|
"""Register metric with bare decorator."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def my_counter():
|
||||||
|
return Counter("test_counter", "A test counter")
|
||||||
|
|
||||||
|
names = [m.name for m in registry.get_all()]
|
||||||
|
assert "my_counter" in names
|
||||||
|
|
||||||
|
def test_register_with_custom_name(self):
|
||||||
|
"""Register metric with custom name."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register(name="custom_name")
|
||||||
|
def my_counter():
|
||||||
|
return Counter("test_counter_2", "A test counter")
|
||||||
|
|
||||||
|
definition = registry.get_all()[0]
|
||||||
|
assert definition.name == "custom_name"
|
||||||
|
|
||||||
|
def test_register_as_collector(self):
|
||||||
|
"""Register metric with collect=True."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def collect_something():
|
||||||
|
pass
|
||||||
|
|
||||||
|
definition = registry.get_all()[0]
|
||||||
|
assert definition.collect is True
|
||||||
|
|
||||||
|
def test_register_preserves_function(self):
|
||||||
|
"""Decorator returns the original function unchanged."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
def my_func():
|
||||||
|
return "original"
|
||||||
|
|
||||||
|
result = registry.register(my_func)
|
||||||
|
assert result is my_func
|
||||||
|
assert result() == "original"
|
||||||
|
|
||||||
|
def test_register_parameterized_preserves_function(self):
|
||||||
|
"""Parameterized decorator returns the original function unchanged."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
def my_func():
|
||||||
|
return "original"
|
||||||
|
|
||||||
|
result = registry.register(name="custom")(my_func)
|
||||||
|
assert result is my_func
|
||||||
|
assert result() == "original"
|
||||||
|
|
||||||
|
def test_get_all(self):
|
||||||
|
"""Get all registered metrics."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def metric_a():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def metric_b():
|
||||||
|
pass
|
||||||
|
|
||||||
|
names = {m.name for m in registry.get_all()}
|
||||||
|
assert names == {"metric_a", "metric_b"}
|
||||||
|
|
||||||
|
def test_get_providers(self):
|
||||||
|
"""Get only provider metrics (collect=False)."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def provider():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def collector():
|
||||||
|
pass
|
||||||
|
|
||||||
|
providers = registry.get_providers()
|
||||||
|
assert len(providers) == 1
|
||||||
|
assert providers[0].name == "provider"
|
||||||
|
|
||||||
|
def test_get_collectors(self):
|
||||||
|
"""Get only collector metrics (collect=True)."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def provider():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def collector():
|
||||||
|
pass
|
||||||
|
|
||||||
|
collectors = registry.get_collectors()
|
||||||
|
assert len(collectors) == 1
|
||||||
|
assert collectors[0].name == "collector"
|
||||||
|
|
||||||
|
def test_register_overwrites_same_name(self):
|
||||||
|
"""Registering with the same name overwrites the previous entry."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register(name="metric")
|
||||||
|
def first():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@registry.register(name="metric")
|
||||||
|
def second():
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert len(registry.get_all()) == 1
|
||||||
|
assert registry.get_all()[0].func is second
|
||||||
|
|
||||||
|
|
||||||
|
class TestGet:
|
||||||
|
"""Tests for MetricsRegistry.get method."""
|
||||||
|
|
||||||
|
def test_get_returns_instance_after_init(self):
|
||||||
|
"""get() returns the metric instance stored by init_metrics."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def my_gauge():
|
||||||
|
return Gauge("get_test_gauge", "A test gauge")
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
instance = registry.get("my_gauge")
|
||||||
|
assert isinstance(instance, Gauge)
|
||||||
|
|
||||||
|
def test_get_raises_for_registered_but_not_initialized(self):
|
||||||
|
"""get() raises KeyError with an informative message when init_metrics was not called."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def my_counter():
|
||||||
|
return Counter("get_uninit_counter", "A counter")
|
||||||
|
|
||||||
|
with pytest.raises(KeyError, match="not been initialized yet"):
|
||||||
|
registry.get("my_counter")
|
||||||
|
|
||||||
|
def test_get_raises_for_unknown_name(self):
|
||||||
|
"""get() raises KeyError when the metric name is not registered at all."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
with pytest.raises(KeyError, match="Unknown metric"):
|
||||||
|
registry.get("nonexistent")
|
||||||
|
|
||||||
|
|
||||||
|
class TestIncludeRegistry:
|
||||||
|
"""Tests for MetricsRegistry.include_registry method."""
|
||||||
|
|
||||||
|
def test_include_empty_registry(self):
|
||||||
|
"""Include an empty registry does nothing."""
|
||||||
|
main = MetricsRegistry()
|
||||||
|
other = MetricsRegistry()
|
||||||
|
|
||||||
|
@main.register
|
||||||
|
def metric_a():
|
||||||
|
pass
|
||||||
|
|
||||||
|
main.include_registry(other)
|
||||||
|
assert len(main.get_all()) == 1
|
||||||
|
|
||||||
|
def test_include_registry_adds_metrics(self):
|
||||||
|
"""Include registry adds all metrics from the other registry."""
|
||||||
|
main = MetricsRegistry()
|
||||||
|
other = MetricsRegistry()
|
||||||
|
|
||||||
|
@main.register
|
||||||
|
def metric_a():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@other.register
|
||||||
|
def metric_b():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@other.register
|
||||||
|
def metric_c():
|
||||||
|
pass
|
||||||
|
|
||||||
|
main.include_registry(other)
|
||||||
|
names = {m.name for m in main.get_all()}
|
||||||
|
assert names == {"metric_a", "metric_b", "metric_c"}
|
||||||
|
|
||||||
|
def test_include_registry_preserves_collect_flag(self):
|
||||||
|
"""Include registry preserves the collect flag."""
|
||||||
|
main = MetricsRegistry()
|
||||||
|
other = MetricsRegistry()
|
||||||
|
|
||||||
|
@other.register(collect=True)
|
||||||
|
def collector():
|
||||||
|
pass
|
||||||
|
|
||||||
|
main.include_registry(other)
|
||||||
|
assert main.get_all()[0].collect is True
|
||||||
|
|
||||||
|
def test_include_registry_raises_on_duplicate(self):
|
||||||
|
"""Include registry raises ValueError on duplicate metric names."""
|
||||||
|
main = MetricsRegistry()
|
||||||
|
other = MetricsRegistry()
|
||||||
|
|
||||||
|
@main.register(name="metric")
|
||||||
|
def metric_main():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@other.register(name="metric")
|
||||||
|
def metric_other():
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="already exists"):
|
||||||
|
main.include_registry(other)
|
||||||
|
|
||||||
|
def test_include_multiple_registries(self):
|
||||||
|
"""Include multiple registries sequentially."""
|
||||||
|
main = MetricsRegistry()
|
||||||
|
sub1 = MetricsRegistry()
|
||||||
|
sub2 = MetricsRegistry()
|
||||||
|
|
||||||
|
@main.register
|
||||||
|
def base():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@sub1.register
|
||||||
|
def sub1_metric():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@sub2.register
|
||||||
|
def sub2_metric():
|
||||||
|
pass
|
||||||
|
|
||||||
|
main.include_registry(sub1)
|
||||||
|
main.include_registry(sub2)
|
||||||
|
|
||||||
|
names = {m.name for m in main.get_all()}
|
||||||
|
assert names == {"base", "sub1_metric", "sub2_metric"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestInitMetrics:
|
||||||
|
"""Tests for init_metrics function."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def metrics_client(self):
|
||||||
|
"""Create a FastAPI app with MetricsRegistry and return a TestClient."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
init_metrics(app, registry)
|
||||||
|
client = TestClient(app)
|
||||||
|
yield client
|
||||||
|
client.close()
|
||||||
|
|
||||||
|
def test_returns_app(self):
|
||||||
|
"""Returns the FastAPI app."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
result = init_metrics(app, registry)
|
||||||
|
assert result is app
|
||||||
|
|
||||||
|
def test_metrics_endpoint_responds(self, metrics_client):
|
||||||
|
"""The /metrics endpoint returns 200."""
|
||||||
|
response = metrics_client.get("/metrics")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_metrics_endpoint_content_type(self, metrics_client):
|
||||||
|
"""The /metrics endpoint returns prometheus content type."""
|
||||||
|
response = metrics_client.get("/metrics")
|
||||||
|
assert "text/plain" in response.headers["content-type"]
|
||||||
|
|
||||||
|
def test_custom_path(self):
|
||||||
|
"""Custom path is used for the metrics endpoint."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
init_metrics(app, registry, path="/custom-metrics")
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
assert client.get("/custom-metrics").status_code == 200
|
||||||
|
assert client.get("/metrics").status_code == 404
|
||||||
|
|
||||||
|
def test_providers_called_at_init(self):
|
||||||
|
"""Provider functions are called once at init time."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
mock = MagicMock()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def my_provider():
|
||||||
|
mock()
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
mock.assert_called_once()
|
||||||
|
|
||||||
|
def test_collectors_called_on_scrape(self):
|
||||||
|
"""Collector functions are called on each scrape."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
mock = MagicMock()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def my_collector():
|
||||||
|
mock()
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
client.get("/metrics")
|
||||||
|
client.get("/metrics")
|
||||||
|
|
||||||
|
assert mock.call_count == 2
|
||||||
|
|
||||||
|
def test_collectors_not_called_at_init(self):
|
||||||
|
"""Collector functions are not called at init time."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
mock = MagicMock()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def my_collector():
|
||||||
|
mock()
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
mock.assert_not_called()
|
||||||
|
|
||||||
|
def test_async_collectors_called_on_scrape(self):
|
||||||
|
"""Async collector functions are awaited on each scrape."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
mock = AsyncMock()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
async def my_async_collector():
|
||||||
|
await mock()
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
client.get("/metrics")
|
||||||
|
client.get("/metrics")
|
||||||
|
|
||||||
|
assert mock.call_count == 2
|
||||||
|
|
||||||
|
def test_mixed_sync_and_async_collectors(self):
|
||||||
|
"""Both sync and async collectors are called on scrape."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
sync_mock = MagicMock()
|
||||||
|
async_mock = AsyncMock()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def sync_collector():
|
||||||
|
sync_mock()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
async def async_collector():
|
||||||
|
await async_mock()
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
client.get("/metrics")
|
||||||
|
|
||||||
|
sync_mock.assert_called_once()
|
||||||
|
async_mock.assert_called_once()
|
||||||
|
|
||||||
|
def test_registered_metrics_appear_in_output(self):
|
||||||
|
"""Metrics created by providers appear in /metrics output."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def my_gauge():
|
||||||
|
g = Gauge("test_gauge_value", "A test gauge")
|
||||||
|
g.set(42)
|
||||||
|
return g
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/metrics")
|
||||||
|
|
||||||
|
assert b"test_gauge_value" in response.content
|
||||||
|
assert b"42.0" in response.content
|
||||||
|
|
||||||
|
def test_endpoint_not_in_openapi_schema(self):
|
||||||
|
"""The /metrics endpoint is not included in the OpenAPI schema."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
schema = app.openapi()
|
||||||
|
assert "/metrics" not in schema.get("paths", {})
|
||||||
|
|
||||||
|
|
||||||
|
class TestMultiProcessMode:
|
||||||
|
"""Tests for multi-process Prometheus mode."""
|
||||||
|
|
||||||
|
def test_multiprocess_with_env_var(self, monkeypatch):
|
||||||
|
"""Multi-process mode works when PROMETHEUS_MULTIPROC_DIR is set."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
monkeypatch.setenv("PROMETHEUS_MULTIPROC_DIR", tmpdir)
|
||||||
|
# Use a separate registry to avoid conflicts with default
|
||||||
|
prom_registry = CollectorRegistry()
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def mp_counter():
|
||||||
|
return Counter(
|
||||||
|
"mp_test_counter",
|
||||||
|
"A multiprocess counter",
|
||||||
|
registry=prom_registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/metrics")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_single_process_without_env_var(self, monkeypatch):
|
||||||
|
"""Single-process mode when PROMETHEUS_MULTIPROC_DIR is not set."""
|
||||||
|
monkeypatch.delenv("PROMETHEUS_MULTIPROC_DIR", raising=False)
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def sp_gauge():
|
||||||
|
g = Gauge("sp_test_gauge", "A single-process gauge")
|
||||||
|
g.set(99)
|
||||||
|
return g
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/metrics")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert b"sp_test_gauge" in response.content
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricsIntegration:
|
||||||
|
"""Integration tests for the metrics module."""
|
||||||
|
|
||||||
|
def test_full_workflow(self):
|
||||||
|
"""Full workflow: registry, providers, collectors, endpoint."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
call_count = {"value": 0}
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def request_counter():
|
||||||
|
return Counter(
|
||||||
|
"integration_requests_total",
|
||||||
|
"Total requests",
|
||||||
|
["method"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def collect_uptime():
|
||||||
|
call_count["value"] += 1
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
response = client.get("/metrics")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert b"integration_requests_total" in response.content
|
||||||
|
assert call_count["value"] == 1
|
||||||
|
|
||||||
|
response = client.get("/metrics")
|
||||||
|
assert call_count["value"] == 2
|
||||||
|
|
||||||
|
def test_multiple_registries_merged(self):
|
||||||
|
"""Multiple registries can be merged and used together."""
|
||||||
|
app = FastAPI()
|
||||||
|
main = MetricsRegistry()
|
||||||
|
sub = MetricsRegistry()
|
||||||
|
|
||||||
|
@main.register
|
||||||
|
def main_gauge():
|
||||||
|
g = Gauge("main_gauge_val", "Main gauge")
|
||||||
|
g.set(1)
|
||||||
|
return g
|
||||||
|
|
||||||
|
@sub.register
|
||||||
|
def sub_gauge():
|
||||||
|
g = Gauge("sub_gauge_val", "Sub gauge")
|
||||||
|
g.set(2)
|
||||||
|
return g
|
||||||
|
|
||||||
|
main.include_registry(sub)
|
||||||
|
init_metrics(app, main)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/metrics")
|
||||||
|
|
||||||
|
assert b"main_gauge_val" in response.content
|
||||||
|
assert b"sub_gauge_val" in response.content
|
||||||
1796
tests/test_models.py
Normal file
1796
tests/test_models.py
Normal file
File diff suppressed because it is too large
Load Diff
518
tests/test_pytest.py
Normal file
518
tests/test_pytest.py
Normal file
@@ -0,0 +1,518 @@
|
|||||||
|
"""Tests for fastapi_toolsets.pytest module."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import Depends, FastAPI
|
||||||
|
from httpx import AsyncClient
|
||||||
|
from sqlalchemy import select, text
|
||||||
|
from sqlalchemy.engine import make_url
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
from fastapi_toolsets.db import get_transaction
|
||||||
|
from fastapi_toolsets.fixtures import Context, FixtureRegistry
|
||||||
|
from fastapi_toolsets.pytest import (
|
||||||
|
create_async_client,
|
||||||
|
create_db_session,
|
||||||
|
create_worker_database,
|
||||||
|
register_fixtures,
|
||||||
|
worker_database_url,
|
||||||
|
)
|
||||||
|
from fastapi_toolsets.pytest.utils import _get_xdist_worker
|
||||||
|
|
||||||
|
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
|
||||||
|
|
||||||
|
test_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
# Fixed UUIDs for test fixtures to allow consistent assertions
|
||||||
|
ROLE_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000001000")
|
||||||
|
ROLE_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000001001")
|
||||||
|
USER_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000002000")
|
||||||
|
USER_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000002001")
|
||||||
|
USER_EXTRA_ID = uuid.UUID("00000000-0000-0000-0000-000000002002")
|
||||||
|
|
||||||
|
|
||||||
|
@test_registry.register(contexts=[Context.BASE])
|
||||||
|
def roles() -> list[Role]:
|
||||||
|
return [
|
||||||
|
Role(id=ROLE_ADMIN_ID, name="plugin_admin"),
|
||||||
|
Role(id=ROLE_USER_ID, name="plugin_user"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@test_registry.register(depends_on=["roles"], contexts=[Context.BASE])
|
||||||
|
def users() -> list[User]:
|
||||||
|
return [
|
||||||
|
User(
|
||||||
|
id=USER_ADMIN_ID,
|
||||||
|
username="plugin_admin",
|
||||||
|
email="padmin@test.com",
|
||||||
|
role_id=ROLE_ADMIN_ID,
|
||||||
|
),
|
||||||
|
User(
|
||||||
|
id=USER_USER_ID,
|
||||||
|
username="plugin_user",
|
||||||
|
email="puser@test.com",
|
||||||
|
role_id=ROLE_USER_ID,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@test_registry.register(depends_on=["users"], contexts=[Context.TESTING])
|
||||||
|
def extra_users() -> list[User]:
|
||||||
|
return [
|
||||||
|
User(
|
||||||
|
id=USER_EXTRA_ID,
|
||||||
|
username="plugin_extra",
|
||||||
|
email="pextra@test.com",
|
||||||
|
role_id=ROLE_USER_ID,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
register_fixtures(test_registry, globals())
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegisterFixtures:
|
||||||
|
"""Tests for register_fixtures function."""
|
||||||
|
|
||||||
|
def test_creates_fixtures_in_namespace(self):
|
||||||
|
"""Fixtures are created in the namespace."""
|
||||||
|
assert "fixture_roles" in globals()
|
||||||
|
assert "fixture_users" in globals()
|
||||||
|
assert "fixture_extra_users" in globals()
|
||||||
|
|
||||||
|
def test_fixtures_are_callable(self):
|
||||||
|
"""Created fixtures are callable."""
|
||||||
|
assert callable(globals()["fixture_roles"])
|
||||||
|
assert callable(globals()["fixture_users"])
|
||||||
|
|
||||||
|
|
||||||
|
class TestGeneratedFixtures:
|
||||||
|
"""Tests for the generated pytest fixtures."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_fixture_loads_data(
|
||||||
|
self, db_session: AsyncSession, fixture_roles: list[Role]
|
||||||
|
):
|
||||||
|
"""Fixture loads data into database and returns it."""
|
||||||
|
assert len(fixture_roles) == 2
|
||||||
|
assert fixture_roles[0].name == "plugin_admin"
|
||||||
|
assert fixture_roles[1].name == "plugin_user"
|
||||||
|
|
||||||
|
# Verify data is in database
|
||||||
|
count = await RoleCrud.count(db_session)
|
||||||
|
assert count == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_fixture_with_dependency(
|
||||||
|
self, db_session: AsyncSession, fixture_users: list[User]
|
||||||
|
):
|
||||||
|
"""Fixture with dependency loads parent fixture first."""
|
||||||
|
# fixture_users depends on fixture_roles
|
||||||
|
# Both should be loaded
|
||||||
|
assert len(fixture_users) == 2
|
||||||
|
|
||||||
|
# Roles should also be in database
|
||||||
|
roles_count = await RoleCrud.count(db_session)
|
||||||
|
assert roles_count == 2
|
||||||
|
|
||||||
|
# Users should be in database
|
||||||
|
users_count = await UserCrud.count(db_session)
|
||||||
|
assert users_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_fixture_returns_models(
|
||||||
|
self, db_session: AsyncSession, fixture_users: list[User]
|
||||||
|
):
|
||||||
|
"""Fixture returns actual model instances."""
|
||||||
|
user = fixture_users[0]
|
||||||
|
assert isinstance(user, User)
|
||||||
|
assert user.id == USER_ADMIN_ID
|
||||||
|
assert user.username == "plugin_admin"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_fixture_relationships_work(
|
||||||
|
self, db_session: AsyncSession, fixture_users: list[User]
|
||||||
|
):
|
||||||
|
"""Loaded fixtures have working relationships."""
|
||||||
|
# Load user with role relationship
|
||||||
|
user = await UserCrud.get(
|
||||||
|
db_session,
|
||||||
|
[User.id == USER_ADMIN_ID],
|
||||||
|
load_options=[selectinload(User.role)],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert user.role is not None
|
||||||
|
assert user.role.name == "plugin_admin"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_chained_dependencies(
|
||||||
|
self, db_session: AsyncSession, fixture_extra_users: list[User]
|
||||||
|
):
|
||||||
|
"""Chained dependencies are resolved correctly."""
|
||||||
|
# fixture_extra_users -> fixture_users -> fixture_roles
|
||||||
|
assert len(fixture_extra_users) == 1
|
||||||
|
|
||||||
|
# All fixtures should be loaded
|
||||||
|
roles_count = await RoleCrud.count(db_session)
|
||||||
|
users_count = await UserCrud.count(db_session)
|
||||||
|
|
||||||
|
assert roles_count == 2
|
||||||
|
assert users_count == 3 # 2 from users + 1 from extra_users
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_can_query_loaded_data(
|
||||||
|
self, db_session: AsyncSession, fixture_users: list[User]
|
||||||
|
):
|
||||||
|
"""Can query the loaded fixture data."""
|
||||||
|
# Get all users loaded by fixture
|
||||||
|
users = await UserCrud.get_multi(
|
||||||
|
db_session,
|
||||||
|
order_by=User.username,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(users) == 2
|
||||||
|
assert users[0].username == "plugin_admin"
|
||||||
|
assert users[1].username == "plugin_user"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_multiple_fixtures_in_same_test(
|
||||||
|
self,
|
||||||
|
db_session: AsyncSession,
|
||||||
|
fixture_roles: list[Role],
|
||||||
|
fixture_users: list[User],
|
||||||
|
):
|
||||||
|
"""Multiple fixtures can be used in the same test."""
|
||||||
|
assert len(fixture_roles) == 2
|
||||||
|
assert len(fixture_users) == 2
|
||||||
|
|
||||||
|
# Both should be in database
|
||||||
|
roles = await RoleCrud.get_multi(db_session)
|
||||||
|
users = await UserCrud.get_multi(db_session)
|
||||||
|
|
||||||
|
assert len(roles) == 2
|
||||||
|
assert len(users) == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateAsyncClient:
|
||||||
|
"""Tests for create_async_client helper."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_creates_working_client(self):
|
||||||
|
"""Client can make requests to the app."""
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health():
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
async with create_async_client(app) as client:
|
||||||
|
assert isinstance(client, AsyncClient)
|
||||||
|
response = await client.get("/health")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"status": "ok"}
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_custom_base_url(self):
|
||||||
|
"""Client uses custom base URL."""
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/test")
|
||||||
|
async def test_endpoint():
|
||||||
|
return {"url": "test"}
|
||||||
|
|
||||||
|
async with create_async_client(app, base_url="http://custom") as client:
|
||||||
|
assert str(client.base_url) == "http://custom"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_client_closes_properly(self):
|
||||||
|
"""Client is properly closed after context exit."""
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
async with create_async_client(app) as client:
|
||||||
|
client_ref = client
|
||||||
|
|
||||||
|
assert client_ref.is_closed
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_dependency_overrides_applied_and_cleaned(self):
|
||||||
|
"""Dependency overrides are applied during the context and removed after."""
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
async def original_dep() -> str:
|
||||||
|
return "original"
|
||||||
|
|
||||||
|
async def override_dep() -> str:
|
||||||
|
return "overridden"
|
||||||
|
|
||||||
|
@app.get("/dep")
|
||||||
|
async def dep_endpoint(value: str = Depends(original_dep)):
|
||||||
|
return {"value": value}
|
||||||
|
|
||||||
|
async with create_async_client(
|
||||||
|
app, dependency_overrides={original_dep: override_dep}
|
||||||
|
) as client:
|
||||||
|
response = await client.get("/dep")
|
||||||
|
assert response.json() == {"value": "overridden"}
|
||||||
|
|
||||||
|
# Overrides should be cleaned up
|
||||||
|
assert original_dep not in app.dependency_overrides
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateDbSession:
|
||||||
|
"""Tests for create_db_session helper."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_creates_working_session(self):
|
||||||
|
"""Session can perform database operations."""
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
async with create_db_session(DATABASE_URL, Base) as session:
|
||||||
|
assert isinstance(session, AsyncSession)
|
||||||
|
|
||||||
|
role = Role(id=role_id, name="test_helper_role")
|
||||||
|
session.add(role)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
result = await session.execute(select(Role).where(Role.id == role_id))
|
||||||
|
fetched = result.scalar_one()
|
||||||
|
assert fetched.name == "test_helper_role"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_tables_created_before_session(self):
|
||||||
|
"""Tables exist when session is yielded."""
|
||||||
|
async with create_db_session(DATABASE_URL, Base) as session:
|
||||||
|
# Should not raise - tables exist
|
||||||
|
result = await session.execute(select(Role))
|
||||||
|
assert result.all() == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_tables_dropped_after_session(self):
|
||||||
|
"""Tables are dropped after session closes when drop_tables=True."""
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||||
|
role = Role(id=role_id, name="will_be_dropped")
|
||||||
|
session.add(role)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Verify tables were dropped by creating new session
|
||||||
|
async with create_db_session(DATABASE_URL, Base) as session:
|
||||||
|
result = await session.execute(select(Role))
|
||||||
|
assert result.all() == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_tables_preserved_when_drop_disabled(self):
|
||||||
|
"""Tables are preserved when drop_tables=False."""
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
|
||||||
|
role = Role(id=role_id, name="preserved_role")
|
||||||
|
session.add(role)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Create another session without dropping
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
|
||||||
|
result = await session.execute(select(Role).where(Role.id == role_id))
|
||||||
|
fetched = result.scalar_one_or_none()
|
||||||
|
assert fetched is not None
|
||||||
|
assert fetched.name == "preserved_role"
|
||||||
|
|
||||||
|
# Cleanup: drop tables manually
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_cleanup_truncates_tables(self):
|
||||||
|
"""Tables are truncated after session closes when cleanup=True."""
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
async with create_db_session(
|
||||||
|
DATABASE_URL, Base, cleanup=True, drop_tables=False
|
||||||
|
) as session:
|
||||||
|
role = Role(id=role_id, name="will_be_cleaned")
|
||||||
|
session.add(role)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Data should have been truncated, but tables still exist
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||||
|
result = await session.execute(select(Role))
|
||||||
|
assert result.all() == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_transaction_commits_visible_to_separate_session(self):
|
||||||
|
"""Data written via get_transaction() is committed and visible to other sessions."""
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
|
||||||
|
# Simulate what _create_fixture_function does: insert via get_transaction
|
||||||
|
# with no explicit commit afterward.
|
||||||
|
async with get_transaction(session):
|
||||||
|
role = Role(id=role_id, name="visible_to_other_session")
|
||||||
|
session.add(role)
|
||||||
|
|
||||||
|
# The data must have been committed (begin/commit, not a savepoint),
|
||||||
|
# so a separate engine/session can read it.
|
||||||
|
other_engine = create_async_engine(DATABASE_URL, echo=False)
|
||||||
|
try:
|
||||||
|
other_session_maker = async_sessionmaker(
|
||||||
|
other_engine, expire_on_commit=False
|
||||||
|
)
|
||||||
|
async with other_session_maker() as other:
|
||||||
|
result = await other.execute(select(Role).where(Role.id == role_id))
|
||||||
|
fetched = result.scalar_one_or_none()
|
||||||
|
assert fetched is not None, (
|
||||||
|
"Fixture data inserted via get_transaction() must be committed "
|
||||||
|
"and visible to a separate session. If create_db_session uses "
|
||||||
|
"create_db_context, auto-begin forces get_transaction() into "
|
||||||
|
"savepoints instead of real commits."
|
||||||
|
)
|
||||||
|
assert fetched.name == "visible_to_other_session"
|
||||||
|
finally:
|
||||||
|
await other_engine.dispose()
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetXdistWorker:
|
||||||
|
"""Tests for _get_xdist_worker helper."""
|
||||||
|
|
||||||
|
def test_returns_default_test_db_without_env_var(
|
||||||
|
self, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
|
"""Returns default_test_db when PYTEST_XDIST_WORKER is not set."""
|
||||||
|
monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
|
||||||
|
assert _get_xdist_worker("my_default") == "my_default"
|
||||||
|
|
||||||
|
def test_returns_worker_name(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""Returns the worker name from the environment variable."""
|
||||||
|
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw0")
|
||||||
|
assert _get_xdist_worker("ignored") == "gw0"
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkerDatabaseUrl:
|
||||||
|
"""Tests for worker_database_url helper."""
|
||||||
|
|
||||||
|
def test_appends_default_test_db_without_xdist(
|
||||||
|
self, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
|
"""default_test_db is appended when not running under xdist."""
|
||||||
|
monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
|
||||||
|
url = "postgresql+asyncpg://user:pass@localhost:5432/mydb"
|
||||||
|
result = worker_database_url(url, default_test_db="fallback")
|
||||||
|
assert make_url(result).database == "mydb_fallback"
|
||||||
|
|
||||||
|
def test_appends_worker_id_to_database_name(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""Worker name is appended to the database name."""
|
||||||
|
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw0")
|
||||||
|
url = "postgresql+asyncpg://user:pass@localhost:5432/db"
|
||||||
|
result = worker_database_url(url, default_test_db="unused")
|
||||||
|
assert make_url(result).database == "db_gw0"
|
||||||
|
|
||||||
|
def test_preserves_url_components(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""Host, port, username, password, and driver are preserved."""
|
||||||
|
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw2")
|
||||||
|
url = "postgresql+asyncpg://myuser:secret@dbhost:6543/testdb"
|
||||||
|
result = make_url(worker_database_url(url, default_test_db="unused"))
|
||||||
|
|
||||||
|
assert result.drivername == "postgresql+asyncpg"
|
||||||
|
assert result.username == "myuser"
|
||||||
|
assert result.password == "secret"
|
||||||
|
assert result.host == "dbhost"
|
||||||
|
assert result.port == 6543
|
||||||
|
assert result.database == "testdb_gw2"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateWorkerDatabase:
|
||||||
|
"""Tests for create_worker_database context manager."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_creates_default_db_without_xdist(
|
||||||
|
self, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
|
"""Without xdist, creates a database suffixed with default_test_db."""
|
||||||
|
monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
|
||||||
|
default_test_db = "no_xdist_default"
|
||||||
|
expected_db = make_url(
|
||||||
|
worker_database_url(DATABASE_URL, default_test_db=default_test_db)
|
||||||
|
).database
|
||||||
|
|
||||||
|
async with create_worker_database(
|
||||||
|
DATABASE_URL, default_test_db=default_test_db
|
||||||
|
) as url:
|
||||||
|
assert make_url(url).database == expected_db
|
||||||
|
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||||
|
{"name": expected_db},
|
||||||
|
)
|
||||||
|
assert result.scalar() == 1
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||||
|
{"name": expected_db},
|
||||||
|
)
|
||||||
|
assert result.scalar() is None
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_creates_and_drops_worker_database(
|
||||||
|
self, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
|
"""Worker database exists inside the context and is dropped after."""
|
||||||
|
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw_test_create")
|
||||||
|
expected_db = make_url(
|
||||||
|
worker_database_url(DATABASE_URL, default_test_db="unused")
|
||||||
|
).database
|
||||||
|
|
||||||
|
async with create_worker_database(DATABASE_URL) as url:
|
||||||
|
assert make_url(url).database == expected_db
|
||||||
|
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||||
|
{"name": expected_db},
|
||||||
|
)
|
||||||
|
assert result.scalar() == 1
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||||
|
{"name": expected_db},
|
||||||
|
)
|
||||||
|
assert result.scalar() is None
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_cleans_up_stale_database(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""A pre-existing worker database is dropped and recreated."""
|
||||||
|
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw_test_stale")
|
||||||
|
expected_db = make_url(
|
||||||
|
worker_database_url(DATABASE_URL, default_test_db="unused")
|
||||||
|
).database
|
||||||
|
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
await conn.execute(text(f"DROP DATABASE IF EXISTS {expected_db}"))
|
||||||
|
await conn.execute(text(f"CREATE DATABASE {expected_db}"))
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
async with create_worker_database(DATABASE_URL) as url:
|
||||||
|
assert make_url(url).database == expected_db
|
||||||
|
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||||
|
{"name": expected_db},
|
||||||
|
)
|
||||||
|
assert result.scalar() is None
|
||||||
|
await engine.dispose()
|
||||||
@@ -1,160 +0,0 @@
|
|||||||
"""Tests for fastapi_toolsets.pytest_plugin module."""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.orm import selectinload
|
|
||||||
|
|
||||||
from fastapi_toolsets.fixtures import Context, FixtureRegistry, register_fixtures
|
|
||||||
|
|
||||||
from .conftest import Role, RoleCrud, User, UserCrud
|
|
||||||
|
|
||||||
test_registry = FixtureRegistry()
|
|
||||||
|
|
||||||
|
|
||||||
@test_registry.register(contexts=[Context.BASE])
|
|
||||||
def roles() -> list[Role]:
|
|
||||||
return [
|
|
||||||
Role(id=1000, name="plugin_admin"),
|
|
||||||
Role(id=1001, name="plugin_user"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@test_registry.register(depends_on=["roles"], contexts=[Context.BASE])
|
|
||||||
def users() -> list[User]:
|
|
||||||
return [
|
|
||||||
User(id=1000, username="plugin_admin", email="padmin@test.com", role_id=1000),
|
|
||||||
User(id=1001, username="plugin_user", email="puser@test.com", role_id=1001),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@test_registry.register(depends_on=["users"], contexts=[Context.TESTING])
|
|
||||||
def extra_users() -> list[User]:
|
|
||||||
return [
|
|
||||||
User(id=1002, username="plugin_extra", email="pextra@test.com", role_id=1001),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
register_fixtures(test_registry, globals())
|
|
||||||
|
|
||||||
|
|
||||||
class TestRegisterFixtures:
|
|
||||||
"""Tests for register_fixtures function."""
|
|
||||||
|
|
||||||
def test_creates_fixtures_in_namespace(self):
|
|
||||||
"""Fixtures are created in the namespace."""
|
|
||||||
assert "fixture_roles" in globals()
|
|
||||||
assert "fixture_users" in globals()
|
|
||||||
assert "fixture_extra_users" in globals()
|
|
||||||
|
|
||||||
def test_fixtures_are_callable(self):
|
|
||||||
"""Created fixtures are callable."""
|
|
||||||
assert callable(globals()["fixture_roles"])
|
|
||||||
assert callable(globals()["fixture_users"])
|
|
||||||
|
|
||||||
|
|
||||||
class TestGeneratedFixtures:
|
|
||||||
"""Tests for the generated pytest fixtures."""
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_fixture_loads_data(
|
|
||||||
self, db_session: AsyncSession, fixture_roles: list[Role]
|
|
||||||
):
|
|
||||||
"""Fixture loads data into database and returns it."""
|
|
||||||
assert len(fixture_roles) == 2
|
|
||||||
assert fixture_roles[0].name == "plugin_admin"
|
|
||||||
assert fixture_roles[1].name == "plugin_user"
|
|
||||||
|
|
||||||
# Verify data is in database
|
|
||||||
count = await RoleCrud.count(db_session, [Role.id >= 1000])
|
|
||||||
assert count == 2
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_fixture_with_dependency(
|
|
||||||
self, db_session: AsyncSession, fixture_users: list[User]
|
|
||||||
):
|
|
||||||
"""Fixture with dependency loads parent fixture first."""
|
|
||||||
# fixture_users depends on fixture_roles
|
|
||||||
# Both should be loaded
|
|
||||||
assert len(fixture_users) == 2
|
|
||||||
|
|
||||||
# Roles should also be in database
|
|
||||||
roles_count = await RoleCrud.count(db_session, [Role.id >= 1000])
|
|
||||||
assert roles_count == 2
|
|
||||||
|
|
||||||
# Users should be in database
|
|
||||||
users_count = await UserCrud.count(db_session, [User.id >= 1000])
|
|
||||||
assert users_count == 2
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_fixture_returns_models(
|
|
||||||
self, db_session: AsyncSession, fixture_users: list[User]
|
|
||||||
):
|
|
||||||
"""Fixture returns actual model instances."""
|
|
||||||
user = fixture_users[0]
|
|
||||||
assert isinstance(user, User)
|
|
||||||
assert user.id == 1000
|
|
||||||
assert user.username == "plugin_admin"
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_fixture_relationships_work(
|
|
||||||
self, db_session: AsyncSession, fixture_users: list[User]
|
|
||||||
):
|
|
||||||
"""Loaded fixtures have working relationships."""
|
|
||||||
# Load user with role relationship
|
|
||||||
user = await UserCrud.get(
|
|
||||||
db_session,
|
|
||||||
[User.id == 1000],
|
|
||||||
load_options=[selectinload(User.role)],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert user.role is not None
|
|
||||||
assert user.role.name == "plugin_admin"
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_chained_dependencies(
|
|
||||||
self, db_session: AsyncSession, fixture_extra_users: list[User]
|
|
||||||
):
|
|
||||||
"""Chained dependencies are resolved correctly."""
|
|
||||||
# fixture_extra_users -> fixture_users -> fixture_roles
|
|
||||||
assert len(fixture_extra_users) == 1
|
|
||||||
|
|
||||||
# All fixtures should be loaded
|
|
||||||
roles_count = await RoleCrud.count(db_session, [Role.id >= 1000])
|
|
||||||
users_count = await UserCrud.count(db_session, [User.id >= 1000])
|
|
||||||
|
|
||||||
assert roles_count == 2
|
|
||||||
assert users_count == 3 # 2 from users + 1 from extra_users
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_can_query_loaded_data(
|
|
||||||
self, db_session: AsyncSession, fixture_users: list[User]
|
|
||||||
):
|
|
||||||
"""Can query the loaded fixture data."""
|
|
||||||
# Get all users loaded by fixture
|
|
||||||
users = await UserCrud.get_multi(
|
|
||||||
db_session,
|
|
||||||
filters=[User.id >= 1000],
|
|
||||||
order_by=User.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(users) == 2
|
|
||||||
assert users[0].username == "plugin_admin"
|
|
||||||
assert users[1].username == "plugin_user"
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_multiple_fixtures_in_same_test(
|
|
||||||
self,
|
|
||||||
db_session: AsyncSession,
|
|
||||||
fixture_roles: list[Role],
|
|
||||||
fixture_users: list[User],
|
|
||||||
):
|
|
||||||
"""Multiple fixtures can be used in the same test."""
|
|
||||||
assert len(fixture_roles) == 2
|
|
||||||
assert len(fixture_users) == 2
|
|
||||||
|
|
||||||
# Both should be in database
|
|
||||||
roles = await RoleCrud.get_multi(db_session, filters=[Role.id >= 1000])
|
|
||||||
users = await UserCrud.get_multi(db_session, filters=[User.id >= 1000])
|
|
||||||
|
|
||||||
assert len(roles) == 2
|
|
||||||
assert len(users) == 2
|
|
||||||
@@ -5,9 +5,13 @@ from pydantic import ValidationError
|
|||||||
|
|
||||||
from fastapi_toolsets.schemas import (
|
from fastapi_toolsets.schemas import (
|
||||||
ApiError,
|
ApiError,
|
||||||
|
CursorPagination,
|
||||||
|
CursorPaginatedResponse,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
|
OffsetPagination,
|
||||||
|
OffsetPaginatedResponse,
|
||||||
PaginatedResponse,
|
PaginatedResponse,
|
||||||
Pagination,
|
PaginationType,
|
||||||
Response,
|
Response,
|
||||||
ResponseStatus,
|
ResponseStatus,
|
||||||
)
|
)
|
||||||
@@ -46,6 +50,31 @@ class TestApiError:
|
|||||||
assert error.desc == "The resource was not found."
|
assert error.desc == "The resource was not found."
|
||||||
assert error.err_code == "RES-404"
|
assert error.err_code == "RES-404"
|
||||||
|
|
||||||
|
def test_data_defaults_to_none(self):
|
||||||
|
"""ApiError data field defaults to None."""
|
||||||
|
error = ApiError(
|
||||||
|
code=404,
|
||||||
|
msg="Not Found",
|
||||||
|
desc="The resource was not found.",
|
||||||
|
err_code="RES-404",
|
||||||
|
)
|
||||||
|
assert error.data is None
|
||||||
|
|
||||||
|
def test_create_with_data(self):
|
||||||
|
"""ApiError can be created with a data payload."""
|
||||||
|
error = ApiError(
|
||||||
|
code=422,
|
||||||
|
msg="Validation Error",
|
||||||
|
desc="2 validation error(s) detected",
|
||||||
|
err_code="VAL-422",
|
||||||
|
data={
|
||||||
|
"errors": [{"field": "name", "message": "required", "type": "missing"}]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert error.data == {
|
||||||
|
"errors": [{"field": "name", "message": "required", "type": "missing"}]
|
||||||
|
}
|
||||||
|
|
||||||
def test_requires_all_fields(self):
|
def test_requires_all_fields(self):
|
||||||
"""ApiError requires all fields."""
|
"""ApiError requires all fields."""
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
@@ -129,12 +158,12 @@ class TestErrorResponse:
|
|||||||
assert data["description"] == "Details"
|
assert data["description"] == "Details"
|
||||||
|
|
||||||
|
|
||||||
class TestPagination:
|
class TestOffsetPagination:
|
||||||
"""Tests for Pagination schema."""
|
"""Tests for OffsetPagination schema (canonical name for offset-based pagination)."""
|
||||||
|
|
||||||
def test_create_pagination(self):
|
def test_create_pagination(self):
|
||||||
"""Create Pagination with all fields."""
|
"""Create OffsetPagination with all fields."""
|
||||||
pagination = Pagination(
|
pagination = OffsetPagination(
|
||||||
total_count=100,
|
total_count=100,
|
||||||
items_per_page=10,
|
items_per_page=10,
|
||||||
page=1,
|
page=1,
|
||||||
@@ -148,7 +177,7 @@ class TestPagination:
|
|||||||
|
|
||||||
def test_last_page_has_more_false(self):
|
def test_last_page_has_more_false(self):
|
||||||
"""Last page has has_more=False."""
|
"""Last page has has_more=False."""
|
||||||
pagination = Pagination(
|
pagination = OffsetPagination(
|
||||||
total_count=25,
|
total_count=25,
|
||||||
items_per_page=10,
|
items_per_page=10,
|
||||||
page=3,
|
page=3,
|
||||||
@@ -158,8 +187,8 @@ class TestPagination:
|
|||||||
assert pagination.has_more is False
|
assert pagination.has_more is False
|
||||||
|
|
||||||
def test_serialization(self):
|
def test_serialization(self):
|
||||||
"""Pagination serializes correctly."""
|
"""OffsetPagination serializes correctly."""
|
||||||
pagination = Pagination(
|
pagination = OffsetPagination(
|
||||||
total_count=50,
|
total_count=50,
|
||||||
items_per_page=20,
|
items_per_page=20,
|
||||||
page=2,
|
page=2,
|
||||||
@@ -172,13 +201,152 @@ class TestPagination:
|
|||||||
assert data["page"] == 2
|
assert data["page"] == 2
|
||||||
assert data["has_more"] is True
|
assert data["has_more"] is True
|
||||||
|
|
||||||
|
def test_total_count_can_be_none(self):
|
||||||
|
"""total_count accepts None (include_total=False mode)."""
|
||||||
|
pagination = OffsetPagination(
|
||||||
|
total_count=None,
|
||||||
|
items_per_page=20,
|
||||||
|
page=1,
|
||||||
|
has_more=True,
|
||||||
|
)
|
||||||
|
assert pagination.total_count is None
|
||||||
|
|
||||||
|
def test_serialization_with_none_total_count(self):
|
||||||
|
"""OffsetPagination serializes total_count=None correctly."""
|
||||||
|
pagination = OffsetPagination(
|
||||||
|
total_count=None,
|
||||||
|
items_per_page=20,
|
||||||
|
page=1,
|
||||||
|
has_more=False,
|
||||||
|
)
|
||||||
|
data = pagination.model_dump()
|
||||||
|
assert data["total_count"] is None
|
||||||
|
|
||||||
|
def test_pages_computed(self):
|
||||||
|
"""pages is ceil(total_count / items_per_page)."""
|
||||||
|
pagination = OffsetPagination(
|
||||||
|
total_count=42,
|
||||||
|
items_per_page=10,
|
||||||
|
page=1,
|
||||||
|
has_more=True,
|
||||||
|
)
|
||||||
|
assert pagination.pages == 5
|
||||||
|
|
||||||
|
def test_pages_exact_division(self):
|
||||||
|
"""pages is exact when total_count is evenly divisible."""
|
||||||
|
pagination = OffsetPagination(
|
||||||
|
total_count=40,
|
||||||
|
items_per_page=10,
|
||||||
|
page=1,
|
||||||
|
has_more=False,
|
||||||
|
)
|
||||||
|
assert pagination.pages == 4
|
||||||
|
|
||||||
|
def test_pages_zero_total(self):
|
||||||
|
"""pages is 0 when total_count is 0."""
|
||||||
|
pagination = OffsetPagination(
|
||||||
|
total_count=0,
|
||||||
|
items_per_page=10,
|
||||||
|
page=1,
|
||||||
|
has_more=False,
|
||||||
|
)
|
||||||
|
assert pagination.pages == 0
|
||||||
|
|
||||||
|
def test_pages_zero_items_per_page(self):
|
||||||
|
"""pages is 0 when items_per_page is 0."""
|
||||||
|
pagination = OffsetPagination(
|
||||||
|
total_count=100,
|
||||||
|
items_per_page=0,
|
||||||
|
page=1,
|
||||||
|
has_more=False,
|
||||||
|
)
|
||||||
|
assert pagination.pages == 0
|
||||||
|
|
||||||
|
def test_pages_none_when_total_count_none(self):
|
||||||
|
"""pages is None when total_count is None (include_total=False)."""
|
||||||
|
pagination = OffsetPagination(
|
||||||
|
total_count=None,
|
||||||
|
items_per_page=20,
|
||||||
|
page=1,
|
||||||
|
has_more=True,
|
||||||
|
)
|
||||||
|
assert pagination.pages is None
|
||||||
|
|
||||||
|
def test_pages_in_serialization(self):
|
||||||
|
"""pages appears in model_dump output."""
|
||||||
|
pagination = OffsetPagination(
|
||||||
|
total_count=25,
|
||||||
|
items_per_page=10,
|
||||||
|
page=1,
|
||||||
|
has_more=True,
|
||||||
|
)
|
||||||
|
data = pagination.model_dump()
|
||||||
|
assert data["pages"] == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestCursorPagination:
|
||||||
|
"""Tests for CursorPagination schema."""
|
||||||
|
|
||||||
|
def test_create_with_next_cursor(self):
|
||||||
|
"""CursorPagination with a next cursor indicates more pages."""
|
||||||
|
pagination = CursorPagination(
|
||||||
|
next_cursor="eyJ2YWx1ZSI6ICIxMjMifQ==",
|
||||||
|
items_per_page=20,
|
||||||
|
has_more=True,
|
||||||
|
)
|
||||||
|
assert pagination.next_cursor == "eyJ2YWx1ZSI6ICIxMjMifQ=="
|
||||||
|
assert pagination.prev_cursor is None
|
||||||
|
assert pagination.items_per_page == 20
|
||||||
|
assert pagination.has_more is True
|
||||||
|
|
||||||
|
def test_create_last_page(self):
|
||||||
|
"""CursorPagination for the last page has next_cursor=None and has_more=False."""
|
||||||
|
pagination = CursorPagination(
|
||||||
|
next_cursor=None,
|
||||||
|
items_per_page=20,
|
||||||
|
has_more=False,
|
||||||
|
)
|
||||||
|
assert pagination.next_cursor is None
|
||||||
|
assert pagination.has_more is False
|
||||||
|
|
||||||
|
def test_prev_cursor_defaults_to_none(self):
|
||||||
|
"""prev_cursor defaults to None."""
|
||||||
|
pagination = CursorPagination(
|
||||||
|
next_cursor=None, items_per_page=10, has_more=False
|
||||||
|
)
|
||||||
|
assert pagination.prev_cursor is None
|
||||||
|
|
||||||
|
def test_prev_cursor_can_be_set(self):
|
||||||
|
"""prev_cursor can be explicitly set."""
|
||||||
|
pagination = CursorPagination(
|
||||||
|
next_cursor="next123",
|
||||||
|
prev_cursor="prev456",
|
||||||
|
items_per_page=10,
|
||||||
|
has_more=True,
|
||||||
|
)
|
||||||
|
assert pagination.prev_cursor == "prev456"
|
||||||
|
|
||||||
|
def test_serialization(self):
|
||||||
|
"""CursorPagination serializes correctly."""
|
||||||
|
pagination = CursorPagination(
|
||||||
|
next_cursor="abc123",
|
||||||
|
prev_cursor="xyz789",
|
||||||
|
items_per_page=20,
|
||||||
|
has_more=True,
|
||||||
|
)
|
||||||
|
data = pagination.model_dump()
|
||||||
|
assert data["next_cursor"] == "abc123"
|
||||||
|
assert data["prev_cursor"] == "xyz789"
|
||||||
|
assert data["items_per_page"] == 20
|
||||||
|
assert data["has_more"] is True
|
||||||
|
|
||||||
|
|
||||||
class TestPaginatedResponse:
|
class TestPaginatedResponse:
|
||||||
"""Tests for PaginatedResponse schema."""
|
"""Tests for PaginatedResponse schema."""
|
||||||
|
|
||||||
def test_create_paginated_response(self):
|
def test_create_paginated_response(self):
|
||||||
"""Create PaginatedResponse with data and pagination."""
|
"""Create PaginatedResponse with data and pagination."""
|
||||||
pagination = Pagination(
|
pagination = OffsetPagination(
|
||||||
total_count=30,
|
total_count=30,
|
||||||
items_per_page=10,
|
items_per_page=10,
|
||||||
page=1,
|
page=1,
|
||||||
@@ -189,13 +357,14 @@ class TestPaginatedResponse:
|
|||||||
pagination=pagination,
|
pagination=pagination,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert isinstance(response.pagination, OffsetPagination)
|
||||||
assert len(response.data) == 2
|
assert len(response.data) == 2
|
||||||
assert response.pagination.total_count == 30
|
assert response.pagination.total_count == 30
|
||||||
assert response.status == ResponseStatus.SUCCESS
|
assert response.status == ResponseStatus.SUCCESS
|
||||||
|
|
||||||
def test_with_custom_message(self):
|
def test_with_custom_message(self):
|
||||||
"""PaginatedResponse with custom message."""
|
"""PaginatedResponse with custom message."""
|
||||||
pagination = Pagination(
|
pagination = OffsetPagination(
|
||||||
total_count=5,
|
total_count=5,
|
||||||
items_per_page=10,
|
items_per_page=10,
|
||||||
page=1,
|
page=1,
|
||||||
@@ -211,28 +380,48 @@ class TestPaginatedResponse:
|
|||||||
|
|
||||||
def test_empty_data(self):
|
def test_empty_data(self):
|
||||||
"""PaginatedResponse with empty data."""
|
"""PaginatedResponse with empty data."""
|
||||||
pagination = Pagination(
|
pagination = OffsetPagination(
|
||||||
total_count=0,
|
total_count=0,
|
||||||
items_per_page=10,
|
items_per_page=10,
|
||||||
page=1,
|
page=1,
|
||||||
has_more=False,
|
has_more=False,
|
||||||
)
|
)
|
||||||
response = PaginatedResponse[dict](
|
response = PaginatedResponse(
|
||||||
data=[],
|
data=[],
|
||||||
pagination=pagination,
|
pagination=pagination,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert isinstance(response.pagination, OffsetPagination)
|
||||||
assert response.data == []
|
assert response.data == []
|
||||||
assert response.pagination.total_count == 0
|
assert response.pagination.total_count == 0
|
||||||
|
|
||||||
|
def test_class_getitem_with_concrete_type_returns_discriminated_union(self):
|
||||||
|
"""PaginatedResponse[T] with a concrete type returns a discriminated Annotated union."""
|
||||||
|
import typing
|
||||||
|
|
||||||
|
alias = PaginatedResponse[dict]
|
||||||
|
args = typing.get_args(alias)
|
||||||
|
# args[0] is the Union, args[1] is the FieldInfo discriminator
|
||||||
|
union_args = typing.get_args(args[0])
|
||||||
|
assert CursorPaginatedResponse[dict] in union_args
|
||||||
|
assert OffsetPaginatedResponse[dict] in union_args
|
||||||
|
|
||||||
|
def test_class_getitem_is_cached(self):
|
||||||
|
"""Repeated subscripting with the same type returns the identical cached object."""
|
||||||
|
assert PaginatedResponse[dict] is PaginatedResponse[dict]
|
||||||
|
|
||||||
|
def test_class_getitem_with_typevar_returns_generic(self):
|
||||||
|
"""PaginatedResponse[TypeVar] falls through to Pydantic generic parametrisation."""
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
alias = PaginatedResponse[T]
|
||||||
|
# Should be a generic alias, not an Annotated union
|
||||||
|
assert not hasattr(alias, "__metadata__")
|
||||||
|
|
||||||
def test_generic_type_hint(self):
|
def test_generic_type_hint(self):
|
||||||
"""PaginatedResponse supports generic type hints."""
|
"""PaginatedResponse supports generic type hints."""
|
||||||
|
pagination = OffsetPagination(
|
||||||
class UserOut:
|
|
||||||
id: int
|
|
||||||
name: str
|
|
||||||
|
|
||||||
pagination = Pagination(
|
|
||||||
total_count=1,
|
total_count=1,
|
||||||
items_per_page=10,
|
items_per_page=10,
|
||||||
page=1,
|
page=1,
|
||||||
@@ -247,7 +436,7 @@ class TestPaginatedResponse:
|
|||||||
|
|
||||||
def test_serialization(self):
|
def test_serialization(self):
|
||||||
"""PaginatedResponse serializes correctly."""
|
"""PaginatedResponse serializes correctly."""
|
||||||
pagination = Pagination(
|
pagination = OffsetPagination(
|
||||||
total_count=100,
|
total_count=100,
|
||||||
items_per_page=10,
|
items_per_page=10,
|
||||||
page=5,
|
page=5,
|
||||||
@@ -265,6 +454,211 @@ class TestPaginatedResponse:
|
|||||||
assert data["data"] == ["item1", "item2"]
|
assert data["data"] == ["item1", "item2"]
|
||||||
assert data["pagination"]["page"] == 5
|
assert data["pagination"]["page"] == 5
|
||||||
|
|
||||||
|
def test_pagination_field_accepts_offset_pagination(self):
|
||||||
|
"""PaginatedResponse.pagination accepts OffsetPagination."""
|
||||||
|
response = PaginatedResponse(
|
||||||
|
data=[1, 2],
|
||||||
|
pagination=OffsetPagination(
|
||||||
|
total_count=2, items_per_page=10, page=1, has_more=False
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert isinstance(response.pagination, OffsetPagination)
|
||||||
|
|
||||||
|
def test_pagination_field_accepts_cursor_pagination(self):
|
||||||
|
"""PaginatedResponse.pagination accepts CursorPagination."""
|
||||||
|
response = PaginatedResponse(
|
||||||
|
data=[1, 2],
|
||||||
|
pagination=CursorPagination(
|
||||||
|
next_cursor=None, items_per_page=10, has_more=False
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert isinstance(response.pagination, CursorPagination)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPaginationType:
|
||||||
|
"""Tests for PaginationType enum."""
|
||||||
|
|
||||||
|
def test_offset_value(self):
|
||||||
|
"""OFFSET has string value 'offset'."""
|
||||||
|
assert PaginationType.OFFSET == "offset"
|
||||||
|
assert PaginationType.OFFSET.value == "offset"
|
||||||
|
|
||||||
|
def test_cursor_value(self):
|
||||||
|
"""CURSOR has string value 'cursor'."""
|
||||||
|
assert PaginationType.CURSOR == "cursor"
|
||||||
|
assert PaginationType.CURSOR.value == "cursor"
|
||||||
|
|
||||||
|
def test_is_string_enum(self):
|
||||||
|
"""PaginationType is a string enum."""
|
||||||
|
assert isinstance(PaginationType.OFFSET, str)
|
||||||
|
assert isinstance(PaginationType.CURSOR, str)
|
||||||
|
|
||||||
|
def test_members(self):
|
||||||
|
"""PaginationType has exactly two members."""
|
||||||
|
assert set(PaginationType) == {PaginationType.OFFSET, PaginationType.CURSOR}
|
||||||
|
|
||||||
|
|
||||||
|
class TestOffsetPaginatedResponse:
|
||||||
|
"""Tests for OffsetPaginatedResponse schema."""
|
||||||
|
|
||||||
|
def test_pagination_type_is_offset(self):
|
||||||
|
"""pagination_type is always PaginationType.OFFSET."""
|
||||||
|
response = OffsetPaginatedResponse(
|
||||||
|
data=[],
|
||||||
|
pagination=OffsetPagination(
|
||||||
|
total_count=0, items_per_page=10, page=1, has_more=False
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert response.pagination_type is PaginationType.OFFSET
|
||||||
|
|
||||||
|
def test_pagination_type_serializes_to_string(self):
|
||||||
|
"""pagination_type serializes to 'offset' in JSON mode."""
|
||||||
|
response = OffsetPaginatedResponse(
|
||||||
|
data=[],
|
||||||
|
pagination=OffsetPagination(
|
||||||
|
total_count=0, items_per_page=10, page=1, has_more=False
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert response.model_dump(mode="json")["pagination_type"] == "offset"
|
||||||
|
|
||||||
|
def test_pagination_field_is_typed(self):
|
||||||
|
"""pagination field is OffsetPagination, not the union."""
|
||||||
|
response = OffsetPaginatedResponse(
|
||||||
|
data=[{"id": 1}],
|
||||||
|
pagination=OffsetPagination(
|
||||||
|
total_count=10, items_per_page=5, page=2, has_more=True
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert isinstance(response.pagination, OffsetPagination)
|
||||||
|
assert response.pagination.total_count == 10
|
||||||
|
assert response.pagination.page == 2
|
||||||
|
|
||||||
|
def test_is_subclass_of_paginated_response(self):
|
||||||
|
"""OffsetPaginatedResponse IS a PaginatedResponse."""
|
||||||
|
response = OffsetPaginatedResponse(
|
||||||
|
data=[],
|
||||||
|
pagination=OffsetPagination(
|
||||||
|
total_count=0, items_per_page=10, page=1, has_more=False
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert isinstance(response, PaginatedResponse)
|
||||||
|
|
||||||
|
def test_pagination_type_default_cannot_be_overridden_to_cursor(self):
|
||||||
|
"""pagination_type rejects values other than OFFSET."""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
OffsetPaginatedResponse(
|
||||||
|
data=[],
|
||||||
|
pagination=OffsetPagination(
|
||||||
|
total_count=0, items_per_page=10, page=1, has_more=False
|
||||||
|
),
|
||||||
|
pagination_type=PaginationType.CURSOR, # type: ignore[arg-type] # ty:ignore[invalid-argument-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_filter_attributes_defaults_to_none(self):
|
||||||
|
"""filter_attributes defaults to None."""
|
||||||
|
response = OffsetPaginatedResponse(
|
||||||
|
data=[],
|
||||||
|
pagination=OffsetPagination(
|
||||||
|
total_count=0, items_per_page=10, page=1, has_more=False
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert response.filter_attributes is None
|
||||||
|
|
||||||
|
def test_full_serialization(self):
|
||||||
|
"""Full JSON serialization includes all expected fields."""
|
||||||
|
response = OffsetPaginatedResponse(
|
||||||
|
data=[{"id": 1}],
|
||||||
|
pagination=OffsetPagination(
|
||||||
|
total_count=1, items_per_page=10, page=1, has_more=False
|
||||||
|
),
|
||||||
|
filter_attributes={"status": ["active"]},
|
||||||
|
)
|
||||||
|
data = response.model_dump(mode="json")
|
||||||
|
|
||||||
|
assert data["pagination_type"] == "offset"
|
||||||
|
assert data["status"] == "SUCCESS"
|
||||||
|
assert data["data"] == [{"id": 1}]
|
||||||
|
assert data["pagination"]["total_count"] == 1
|
||||||
|
assert data["filter_attributes"] == {"status": ["active"]}
|
||||||
|
|
||||||
|
|
||||||
|
class TestCursorPaginatedResponse:
|
||||||
|
"""Tests for CursorPaginatedResponse schema."""
|
||||||
|
|
||||||
|
def test_pagination_type_is_cursor(self):
|
||||||
|
"""pagination_type is always PaginationType.CURSOR."""
|
||||||
|
response = CursorPaginatedResponse(
|
||||||
|
data=[],
|
||||||
|
pagination=CursorPagination(
|
||||||
|
next_cursor=None, items_per_page=10, has_more=False
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert response.pagination_type is PaginationType.CURSOR
|
||||||
|
|
||||||
|
def test_pagination_type_serializes_to_string(self):
|
||||||
|
"""pagination_type serializes to 'cursor' in JSON mode."""
|
||||||
|
response = CursorPaginatedResponse(
|
||||||
|
data=[],
|
||||||
|
pagination=CursorPagination(
|
||||||
|
next_cursor=None, items_per_page=10, has_more=False
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert response.model_dump(mode="json")["pagination_type"] == "cursor"
|
||||||
|
|
||||||
|
def test_pagination_field_is_typed(self):
|
||||||
|
"""pagination field is CursorPagination, not the union."""
|
||||||
|
response = CursorPaginatedResponse(
|
||||||
|
data=[{"id": 1}],
|
||||||
|
pagination=CursorPagination(
|
||||||
|
next_cursor="abc123",
|
||||||
|
prev_cursor=None,
|
||||||
|
items_per_page=20,
|
||||||
|
has_more=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert isinstance(response.pagination, CursorPagination)
|
||||||
|
assert response.pagination.next_cursor == "abc123"
|
||||||
|
assert response.pagination.has_more is True
|
||||||
|
|
||||||
|
def test_is_subclass_of_paginated_response(self):
|
||||||
|
"""CursorPaginatedResponse IS a PaginatedResponse."""
|
||||||
|
response = CursorPaginatedResponse(
|
||||||
|
data=[],
|
||||||
|
pagination=CursorPagination(
|
||||||
|
next_cursor=None, items_per_page=10, has_more=False
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert isinstance(response, PaginatedResponse)
|
||||||
|
|
||||||
|
def test_pagination_type_default_cannot_be_overridden_to_offset(self):
|
||||||
|
"""pagination_type rejects values other than CURSOR."""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
CursorPaginatedResponse(
|
||||||
|
data=[],
|
||||||
|
pagination=CursorPagination(
|
||||||
|
next_cursor=None, items_per_page=10, has_more=False
|
||||||
|
),
|
||||||
|
pagination_type=PaginationType.OFFSET, # type: ignore[arg-type] # ty:ignore[invalid-argument-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_full_serialization(self):
|
||||||
|
"""Full JSON serialization includes all expected fields."""
|
||||||
|
response = CursorPaginatedResponse(
|
||||||
|
data=[{"id": 1}],
|
||||||
|
pagination=CursorPagination(
|
||||||
|
next_cursor="tok_next",
|
||||||
|
prev_cursor="tok_prev",
|
||||||
|
items_per_page=10,
|
||||||
|
has_more=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
data = response.model_dump(mode="json")
|
||||||
|
|
||||||
|
assert data["pagination_type"] == "cursor"
|
||||||
|
assert data["status"] == "SUCCESS"
|
||||||
|
assert data["pagination"]["next_cursor"] == "tok_next"
|
||||||
|
assert data["pagination"]["prev_cursor"] == "tok_prev"
|
||||||
|
|
||||||
|
|
||||||
class TestFromAttributes:
|
class TestFromAttributes:
|
||||||
"""Tests for from_attributes config (ORM mode)."""
|
"""Tests for from_attributes config (ORM mode)."""
|
||||||
|
|||||||
148
zensical.toml
Normal file
148
zensical.toml
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
[project]
|
||||||
|
site_name = "FastAPI Toolsets"
|
||||||
|
site_description = "Production-ready utilities for FastAPI applications."
|
||||||
|
site_author = "d3vyce"
|
||||||
|
site_url = "https://fastapi-toolsets.d3vyce.fr"
|
||||||
|
copyright = "Copyright © 2026 d3vyce"
|
||||||
|
repo_url = "https://github.com/d3vyce/fastapi-toolsets"
|
||||||
|
|
||||||
|
[project.theme]
|
||||||
|
custom_dir = "docs/overrides"
|
||||||
|
language = "en"
|
||||||
|
features = [
|
||||||
|
"announce.dismiss",
|
||||||
|
"content.action.view",
|
||||||
|
"content.code.annotate",
|
||||||
|
"content.code.copy",
|
||||||
|
"content.code.select",
|
||||||
|
"content.footnote.tooltips",
|
||||||
|
"content.tabs.link",
|
||||||
|
"content.tooltips",
|
||||||
|
"navigation.footer",
|
||||||
|
"navigation.indexes",
|
||||||
|
"navigation.instant",
|
||||||
|
"navigation.instant.prefetch",
|
||||||
|
"navigation.path",
|
||||||
|
"navigation.sections",
|
||||||
|
"navigation.tabs",
|
||||||
|
"navigation.top",
|
||||||
|
"navigation.tracking",
|
||||||
|
"search.highlight",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[project.theme.palette]]
|
||||||
|
scheme = "default"
|
||||||
|
toggle.icon = "lucide/sun"
|
||||||
|
toggle.name = "Switch to dark mode"
|
||||||
|
|
||||||
|
[[project.theme.palette]]
|
||||||
|
scheme = "slate"
|
||||||
|
toggle.icon = "lucide/moon"
|
||||||
|
toggle.name = "Switch to light mode"
|
||||||
|
|
||||||
|
[project.theme.font]
|
||||||
|
text = "Inter"
|
||||||
|
code = "Jetbrains Mono"
|
||||||
|
|
||||||
|
[project.theme.icon]
|
||||||
|
repo = "fontawesome/brands/github"
|
||||||
|
|
||||||
|
[project.plugins.mkdocstrings.handlers.python]
|
||||||
|
inventories = ["https://docs.python.org/3/objects.inv"]
|
||||||
|
paths = ["src"]
|
||||||
|
|
||||||
|
[project.plugins.mkdocstrings.handlers.python.options]
|
||||||
|
docstring_style = "google"
|
||||||
|
inherited_members = true
|
||||||
|
show_source = false
|
||||||
|
show_root_heading = true
|
||||||
|
|
||||||
|
[project.markdown_extensions]
|
||||||
|
abbr = {}
|
||||||
|
admonition = {}
|
||||||
|
attr_list = {}
|
||||||
|
def_list = {}
|
||||||
|
footnotes = {}
|
||||||
|
md_in_html = {}
|
||||||
|
"pymdownx.arithmatex" = {generic = true}
|
||||||
|
"pymdownx.betterem" = {}
|
||||||
|
"pymdownx.caret" = {}
|
||||||
|
"pymdownx.details" = {}
|
||||||
|
"pymdownx.emoji" = {}
|
||||||
|
"pymdownx.inlinehilite" = {}
|
||||||
|
"pymdownx.keys" = {}
|
||||||
|
"pymdownx.magiclink" = {}
|
||||||
|
"pymdownx.mark" = {}
|
||||||
|
"pymdownx.smartsymbols" = {}
|
||||||
|
"pymdownx.tasklist" = {custom_checkbox = true}
|
||||||
|
"pymdownx.tilde" = {}
|
||||||
|
|
||||||
|
[project.markdown_extensions.pymdownx.emoji]
|
||||||
|
emoji_index = "zensical.extensions.emoji.twemoji"
|
||||||
|
emoji_generator = "zensical.extensions.emoji.to_svg"
|
||||||
|
|
||||||
|
[project.markdown_extensions."pymdownx.highlight"]
|
||||||
|
anchor_linenums = true
|
||||||
|
line_spans = "__span"
|
||||||
|
pygments_lang_class = true
|
||||||
|
|
||||||
|
[project.markdown_extensions."pymdownx.superfences"]
|
||||||
|
custom_fences = [{name = "mermaid", class = "mermaid"}]
|
||||||
|
|
||||||
|
[project.markdown_extensions."pymdownx.tabbed"]
|
||||||
|
alternate_style = true
|
||||||
|
combine_header_slug = true
|
||||||
|
|
||||||
|
[project.markdown_extensions."toc"]
|
||||||
|
permalink = true
|
||||||
|
|
||||||
|
[project.markdown_extensions."pymdownx.snippets"]
|
||||||
|
base_path = ["."]
|
||||||
|
check_paths = true
|
||||||
|
|
||||||
|
[[project.nav]]
|
||||||
|
Home = "index.md"
|
||||||
|
|
||||||
|
[[project.nav]]
|
||||||
|
Modules = [
|
||||||
|
{CLI = "module/cli.md"},
|
||||||
|
{CRUD = "module/crud.md"},
|
||||||
|
{Database = "module/db.md"},
|
||||||
|
{Dependencies = "module/dependencies.md"},
|
||||||
|
{Exceptions = "module/exceptions.md"},
|
||||||
|
{Fixtures = "module/fixtures.md"},
|
||||||
|
{Logger = "module/logger.md"},
|
||||||
|
{Metrics = "module/metrics.md"},
|
||||||
|
{Models = "module/models.md"},
|
||||||
|
{Pytest = "module/pytest.md"},
|
||||||
|
{Schemas = "module/schemas.md"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[[project.nav]]
|
||||||
|
Reference = [
|
||||||
|
{CLI = "reference/cli.md"},
|
||||||
|
{CRUD = "reference/crud.md"},
|
||||||
|
{Database = "reference/db.md"},
|
||||||
|
{Dependencies = "reference/dependencies.md"},
|
||||||
|
{Exceptions = "reference/exceptions.md"},
|
||||||
|
{Fixtures = "reference/fixtures.md"},
|
||||||
|
{Logger = "reference/logger.md"},
|
||||||
|
{Metrics = "reference/metrics.md"},
|
||||||
|
{Models = "reference/models.md"},
|
||||||
|
{Pytest = "reference/pytest.md"},
|
||||||
|
{Schemas = "reference/schemas.md"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[[project.nav]]
|
||||||
|
Examples = [
|
||||||
|
{"Pagination & Search" = "examples/pagination-search.md"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[[project.nav]]
|
||||||
|
Migration = [
|
||||||
|
{"v3.0" = "migration/v3.md"},
|
||||||
|
{"v2.0" = "migration/v2.md"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[[project.nav]]
|
||||||
|
"Changelog ↗" = "https://github.com/d3vyce/fastapi-toolsets/releases"
|
||||||
Reference in New Issue
Block a user