mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 14:46:24 +02:00
Compare commits
120 Commits
v0.3.0
...
337985ef38
| Author | SHA1 | Date | |
|---|---|---|---|
|
337985ef38
|
|||
|
|
b5e6dfe6fe | ||
|
|
6681b7ade7 | ||
|
|
6981c33dc8 | ||
|
|
0c7a99039c | ||
|
|
bcb5b0bfda | ||
|
100e1c1aa9
|
|||
|
|
db6c7a565f | ||
|
|
768e405554 | ||
|
|
f0223ebde4 | ||
|
|
f8c9bf69fe | ||
|
6d6fae5538
|
|||
|
|
fc9cd1f034 | ||
|
|
f82225f995 | ||
|
|
e62612a93a | ||
|
|
56f0ea291e | ||
|
|
ee896009ee | ||
|
|
65bf928e12 | ||
|
|
2e9c6c0c90 | ||
|
2c494fcd17
|
|||
|
|
fd7269a372 | ||
|
|
c863744012 | ||
|
|
aedcbf4e04 | ||
|
|
19c013bdec | ||
|
|
81407c3038 | ||
|
|
0fb00d44da | ||
|
|
19232d3436 | ||
|
1eafcb3873
|
|||
|
|
0d67fbb58d | ||
|
|
a59f098930 | ||
|
|
96e34ba8af | ||
|
|
26d649791f | ||
|
dde5183e68
|
|||
|
|
e4250a9910 | ||
|
|
4800941934 | ||
|
0cc21d2012
|
|||
|
|
a3245d50f0 | ||
|
|
baebf022f6 | ||
|
|
96d445e3f3 | ||
|
|
80306e1af3 | ||
|
|
fd999b63f1 | ||
|
|
c0f352b914 | ||
|
c4c760484b
|
|||
|
|
432e0722e0 | ||
|
|
e732e54518 | ||
|
|
05b5a2c876 | ||
|
|
4a020c56d1 | ||
|
56d365d14b
|
|||
|
|
a257d85d45 | ||
|
117675d02f
|
|||
|
|
d7ad7308c5 | ||
|
8d57bf9525
|
|||
|
|
5a08ec2f57 | ||
|
|
433dc55fcd | ||
|
|
0b2abd8c43 | ||
|
|
07c99be89b | ||
|
|
9b75cc7dfc | ||
|
|
6144b383eb | ||
|
7ec407834a
|
|||
|
|
7da34f33a2 | ||
|
8c8911fb27
|
|||
|
|
c0c3b38054 | ||
|
e17d385910
|
|||
|
|
6cf7df55ef | ||
|
|
7482bc5dad | ||
|
|
9d07dfea85 | ||
|
|
31678935aa | ||
|
|
823a0b3e36 | ||
|
1591cd3d64
|
|||
|
|
6714ceeb92 | ||
|
|
73fae04333 | ||
|
|
32ed36e102 | ||
|
|
48567310bc | ||
|
|
de51ed4675 | ||
|
|
794767edbb | ||
|
|
9c136f05bb | ||
|
3299a439fe
|
|||
|
|
d5b22a72fd | ||
|
|
c32f2e18be | ||
|
d971261f98
|
|||
|
|
74a54b7396 | ||
|
|
19805ab376 | ||
|
|
d4498e2063 | ||
| f59c1a17e2 | |||
|
|
8982ba18e3 | ||
|
|
71fe6f478f | ||
|
|
1cfbf14986 | ||
|
|
e3ff535b7e | ||
|
|
8825c772ce | ||
|
|
c8c263ca8f | ||
|
2020fa2f92
|
|||
|
|
1ea316bef4 | ||
|
|
ced1a655f2 | ||
|
|
290b2a06ec | ||
|
|
baa9711665 | ||
|
d526969d0e
|
|||
|
|
e24153053e | ||
|
348ed4c148
|
|||
|
bd6e90de1b
|
|||
|
|
4404fb3df9 | ||
|
|
f68793fbdb | ||
|
|
3a69c3c788 | ||
|
e861a0a49a
|
|||
|
|
cb2cf572e0 | ||
|
494869a172
|
|||
|
|
e0bc93096d | ||
|
1ff94eb9d3
|
|||
|
|
97ab10edcd | ||
|
|
3ff7ff18bb | ||
|
0f50c8a0f0
|
|||
|
|
691fb78fda | ||
|
|
34ef4da317 | ||
|
|
8c287b3ce7 | ||
|
54f5479c24
|
|||
|
|
f467754df1 | ||
|
b57ce40b05
|
|||
|
5264631550
|
|||
|
a76f7c439d
|
|||
|
|
d14551781c | ||
|
|
577e087321 |
4
.github/workflows/build-release.yml
vendored
4
.github/workflows/build-release.yml
vendored
@@ -17,10 +17,10 @@ jobs:
|
||||
uses: astral-sh/setup-uv@v7
|
||||
|
||||
- name: Set up Python
|
||||
run: uv python install 3.13
|
||||
run: uv python install 3.14
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync
|
||||
run: uv sync --group dev
|
||||
|
||||
- name: Build
|
||||
run: uv build
|
||||
|
||||
20
.github/workflows/ci.yml
vendored
20
.github/workflows/ci.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
run: uv python install 3.13
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --extra dev
|
||||
run: uv sync --group dev
|
||||
|
||||
- name: Run Ruff linter
|
||||
run: uv run ruff check .
|
||||
@@ -45,7 +45,7 @@ jobs:
|
||||
run: uv python install 3.13
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --extra dev
|
||||
run: uv sync --group dev
|
||||
|
||||
- name: Run ty
|
||||
run: uv run ty check
|
||||
@@ -56,7 +56,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.11", "3.12", "3.13"]
|
||||
python-version: ["3.11", "3.12", "3.13", "3.14"]
|
||||
|
||||
services:
|
||||
postgres:
|
||||
@@ -83,18 +83,26 @@ jobs:
|
||||
run: uv python install ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --extra dev
|
||||
run: uv sync --group dev
|
||||
|
||||
- name: Run tests with coverage
|
||||
env:
|
||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/test_db
|
||||
run: |
|
||||
uv run pytest --cov --cov-report=xml --cov-report=term-missing
|
||||
uv run pytest --cov --cov-report=xml --cov-report=term-missing --junitxml=junit.xml -o junit_family=legacy
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
if: matrix.python-version == '3.13'
|
||||
if: matrix.python-version == '3.14'
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
report_type: coverage
|
||||
files: ./coverage.xml
|
||||
fail_ci_if_error: false
|
||||
|
||||
- name: Upload test results to Codecov
|
||||
if: matrix.python-version == '3.14'
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
report_type: test_results
|
||||
|
||||
38
.github/workflows/docs.yml
vendored
Normal file
38
.github/workflows/docs.yml
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
name: Documentation
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pages: write
|
||||
id-token: write
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
environment:
|
||||
name: github-pages
|
||||
url: ${{ steps.deployment.outputs.page_url }}
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/configure-pages@v5
|
||||
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v7
|
||||
|
||||
- name: Set up Python
|
||||
run: uv python install 3.13
|
||||
|
||||
- run: uv sync --group dev
|
||||
|
||||
- run: uv run zensical build --clean
|
||||
|
||||
- uses: actions/upload-pages-artifact@v4
|
||||
with:
|
||||
path: site
|
||||
|
||||
- uses: actions/deploy-pages@v4
|
||||
id: deployment
|
||||
@@ -1 +1 @@
|
||||
3.13
|
||||
3.14
|
||||
|
||||
37
README.md
37
README.md
@@ -1,6 +1,6 @@
|
||||
# FastAPI Toolsets
|
||||
|
||||
FastAPI Toolsets provides production-ready utilities for FastAPI applications built with async SQLAlchemy and PostgreSQL. It includes generic CRUD operations, a fixture system with dependency resolution, a Django-like CLI, standardized API responses, and structured exception handling with automatic OpenAPI documentation.
|
||||
A modular collection of production-ready utilities for FastAPI. Install only what you need — from async CRUD and database helpers to CLI tooling, Prometheus metrics, and pytest fixtures. Each module is independently installable via optional extras, keeping your dependency footprint minimal.
|
||||
|
||||
[](https://github.com/d3vyce/fastapi-toolsets/actions/workflows/ci.yml)
|
||||
[](https://codecov.io/gh/d3vyce/fastapi-toolsets)
|
||||
@@ -20,17 +20,44 @@ FastAPI Toolsets provides production-ready utilities for FastAPI applications bu
|
||||
|
||||
## Installation
|
||||
|
||||
The base package includes the core modules (CRUD, database, schemas, exceptions, fixtures, dependencies, model mixins, logging):
|
||||
|
||||
```bash
|
||||
uv add fastapi-toolsets
|
||||
```
|
||||
|
||||
Install only the extras you need:
|
||||
|
||||
```bash
|
||||
uv add "fastapi-toolsets[cli]"
|
||||
uv add "fastapi-toolsets[metrics]"
|
||||
uv add "fastapi-toolsets[pytest]"
|
||||
```
|
||||
|
||||
Or install everything:
|
||||
|
||||
```bash
|
||||
uv add "fastapi-toolsets[all]"
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
- **CRUD**: Generic async CRUD operations with `CrudFactory`
|
||||
- **Fixtures**: Fixture system with dependency management, context support and pytest integration
|
||||
- **CLI**: Django-like command-line interface for fixtures and custom commands
|
||||
- **Standardized API Responses**: Consistent response format across your API
|
||||
### Core
|
||||
|
||||
- **CRUD**: Generic async CRUD operations with `CrudFactory`, built-in full-text/faceted search and Offset/Cursor pagination.
|
||||
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
|
||||
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
|
||||
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
|
||||
- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`) and lifecycle callbacks (`WatchedFieldsMixin`, `@watch`) that fire after commit for insert, update, and delete events
|
||||
- **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`.
|
||||
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
||||
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`
|
||||
|
||||
### Optional
|
||||
|
||||
- **CLI**: Django-like command-line interface with fixture management and custom commands support
|
||||
- **Metrics**: Prometheus metrics endpoint with provider/collector registry
|
||||
- **Pytest Helpers**: Async test client, database session management, `pytest-xdist` support, and table cleanup utilities
|
||||
|
||||
## License
|
||||
|
||||
|
||||
185
docs/examples/pagination-search.md
Normal file
185
docs/examples/pagination-search.md
Normal file
@@ -0,0 +1,185 @@
|
||||
# Pagination & search
|
||||
|
||||
This example builds an articles listing endpoint that supports **offset pagination**, **cursor pagination**, **full-text search**, **faceted filtering**, and **sorting** — all from a single `CrudFactory` definition.
|
||||
|
||||
## Models
|
||||
|
||||
```python title="models.py"
|
||||
--8<-- "docs_src/examples/pagination_search/models.py"
|
||||
```
|
||||
|
||||
## Schemas
|
||||
|
||||
```python title="schemas.py"
|
||||
--8<-- "docs_src/examples/pagination_search/schemas.py"
|
||||
```
|
||||
|
||||
## Crud
|
||||
|
||||
Declare `searchable_fields`, `facet_fields`, and `order_fields` once on [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory). All endpoints built from this class share the same defaults and can override them per call.
|
||||
|
||||
```python title="crud.py"
|
||||
--8<-- "docs_src/examples/pagination_search/crud.py"
|
||||
```
|
||||
|
||||
## Session dependency
|
||||
|
||||
```python title="db.py"
|
||||
--8<-- "docs_src/examples/pagination_search/db.py"
|
||||
```
|
||||
|
||||
!!! info "Deploy a Postgres DB with docker"
|
||||
```bash
|
||||
docker run -d --name postgres -e POSTGRES_USER=postgres -e POSTGRES_PASSWORD=postgres -e POSTGRES_DB=postgres -p 5432:5432 postgres:18-alpine
|
||||
```
|
||||
|
||||
|
||||
## App
|
||||
|
||||
```python title="app.py"
|
||||
--8<-- "docs_src/examples/pagination_search/app.py"
|
||||
```
|
||||
|
||||
|
||||
## Routes
|
||||
|
||||
```python title="routes.py:1:17"
|
||||
--8<-- "docs_src/examples/pagination_search/routes.py:1:17"
|
||||
```
|
||||
|
||||
### Offset pagination
|
||||
|
||||
Best for admin panels or any UI that needs a total item count and numbered pages.
|
||||
|
||||
```python title="routes.py:20:40"
|
||||
--8<-- "docs_src/examples/pagination_search/routes.py:20:40"
|
||||
```
|
||||
|
||||
**Example request**
|
||||
|
||||
```
|
||||
GET /articles/offset?page=2&items_per_page=10&search=fastapi&status=published&order_by=title&order=asc
|
||||
```
|
||||
|
||||
**Example response**
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"pagination_type": "offset",
|
||||
"data": [
|
||||
{ "id": "3f47ac69-...", "title": "FastAPI tips", "status": "published", ... }
|
||||
],
|
||||
"pagination": {
|
||||
"total_count": 42,
|
||||
"pages": 5,
|
||||
"page": 2,
|
||||
"items_per_page": 10,
|
||||
"has_more": true
|
||||
},
|
||||
"filter_attributes": {
|
||||
"status": ["archived", "draft", "published"],
|
||||
"name": ["backend", "frontend", "python"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`filter_attributes` always reflects the values visible **after** applying the active filters. Use it to populate filter dropdowns on the client.
|
||||
|
||||
To skip the `COUNT(*)` query for better performance on large tables, pass `include_total=False`. `pagination.total_count` will be `null` in the response, while `has_more` remains accurate.
|
||||
|
||||
### Cursor pagination
|
||||
|
||||
Best for feeds, infinite scroll, or any high-throughput API where offset performance degrades.
|
||||
|
||||
```python title="routes.py:43:63"
|
||||
--8<-- "docs_src/examples/pagination_search/routes.py:43:63"
|
||||
```
|
||||
|
||||
**Example request**
|
||||
|
||||
```
|
||||
GET /articles/cursor?items_per_page=10&status=published&order_by=created_at&order=desc
|
||||
```
|
||||
|
||||
**Example response**
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"pagination_type": "cursor",
|
||||
"data": [
|
||||
{ "id": "3f47ac69-...", "title": "FastAPI tips", "status": "published", ... }
|
||||
],
|
||||
"pagination": {
|
||||
"next_cursor": "eyJ2YWx1ZSI6ICIzZjQ3YWM2OS0uLi4ifQ==",
|
||||
"prev_cursor": null,
|
||||
"items_per_page": 10,
|
||||
"has_more": true
|
||||
},
|
||||
"filter_attributes": {
|
||||
"status": ["published"],
|
||||
"name": ["backend", "python"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Pass `next_cursor` as the `cursor` query parameter on the next request to advance to the next page.
|
||||
|
||||
### Unified endpoint (both strategies)
|
||||
|
||||
!!! info "Added in `v2.3.0`"
|
||||
|
||||
[`paginate()`](../module/crud.md#unified-paginate--both-strategies-on-one-endpoint) lets a single endpoint support both strategies via a `pagination_type` query parameter. The `pagination_type` field in the response acts as a discriminator for frontend tooling.
|
||||
|
||||
```python title="routes.py:66:90"
|
||||
--8<-- "docs_src/examples/pagination_search/routes.py:66:90"
|
||||
```
|
||||
|
||||
**Offset request** (default)
|
||||
|
||||
```
|
||||
GET /articles/?pagination_type=offset&page=1&items_per_page=10
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"pagination_type": "offset",
|
||||
"data": ["..."],
|
||||
"pagination": { "total_count": 42, "pages": 5, "page": 1, "items_per_page": 10, "has_more": true }
|
||||
}
|
||||
```
|
||||
|
||||
**Cursor request**
|
||||
|
||||
```
|
||||
GET /articles/?pagination_type=cursor&items_per_page=10
|
||||
GET /articles/?pagination_type=cursor&items_per_page=10&cursor=eyJ2YWx1ZSI6...
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"pagination_type": "cursor",
|
||||
"data": ["..."],
|
||||
"pagination": { "next_cursor": "eyJ2YWx1ZSI6...", "prev_cursor": null, "items_per_page": 10, "has_more": true }
|
||||
}
|
||||
```
|
||||
|
||||
## Search behaviour
|
||||
|
||||
Both endpoints inherit the same `searchable_fields` declared on `ArticleCrud`:
|
||||
|
||||
Search is **case-insensitive** and uses a `LIKE %query%` pattern. Pass a [`SearchConfig`](../reference/crud.md#fastapi_toolsets.crud.search.SearchConfig) instead of a plain string to control case sensitivity or switch to `match_mode="all"` (AND across all fields instead of OR).
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.crud import SearchConfig
|
||||
|
||||
# Both title AND body must contain "fastapi"
|
||||
result = await ArticleCrud.offset_paginate(
|
||||
session,
|
||||
search=SearchConfig(query="fastapi", case_sensitive=True, match_mode="all"),
|
||||
search_fields=[Article.title, Article.body],
|
||||
)
|
||||
```
|
||||
68
docs/index.md
Normal file
68
docs/index.md
Normal file
@@ -0,0 +1,68 @@
|
||||
# FastAPI Toolsets
|
||||
|
||||
A modular collection of production-ready utilities for FastAPI. Install only what you need — from async CRUD and database helpers to CLI tooling, Prometheus metrics, and pytest fixtures. Each module is independently installable via optional extras, keeping your dependency footprint minimal.
|
||||
|
||||
[](https://github.com/d3vyce/fastapi-toolsets/actions/workflows/ci.yml)
|
||||
[](https://codecov.io/gh/d3vyce/fastapi-toolsets)
|
||||
[](https://github.com/astral-sh/ty)
|
||||
[](https://github.com/astral-sh/uv)
|
||||
[](https://github.com/astral-sh/ruff)
|
||||
[](https://www.python.org/downloads/)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
|
||||
---
|
||||
|
||||
**Documentation**: [https://fastapi-toolsets.d3vyce.fr](https://fastapi-toolsets.d3vyce.fr)
|
||||
|
||||
**Source Code**: [https://github.com/d3vyce/fastapi-toolsets](https://github.com/d3vyce/fastapi-toolsets)
|
||||
|
||||
---
|
||||
|
||||
## Installation
|
||||
|
||||
The base package includes the core modules (CRUD, database, schemas, exceptions, fixtures, dependencies, model mixins, logging):
|
||||
|
||||
```bash
|
||||
uv add fastapi-toolsets
|
||||
```
|
||||
|
||||
Install only the extras you need:
|
||||
|
||||
```bash
|
||||
uv add "fastapi-toolsets[cli]"
|
||||
uv add "fastapi-toolsets[metrics]"
|
||||
uv add "fastapi-toolsets[pytest]"
|
||||
```
|
||||
|
||||
Or install everything:
|
||||
|
||||
```bash
|
||||
uv add "fastapi-toolsets[all]"
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
### Core
|
||||
|
||||
- **CRUD**: Generic async CRUD operations with `CrudFactory`, built-in full-text/faceted search and Offset/Cursor pagination.
|
||||
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
|
||||
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
|
||||
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
|
||||
- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`) and lifecycle callbacks (`WatchedFieldsMixin`) that fire after commit for insert, update, and delete events.
|
||||
- **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`.
|
||||
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
||||
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`
|
||||
|
||||
### Optional
|
||||
|
||||
- **CLI**: Django-like command-line interface with fixture management and custom commands support
|
||||
- **Metrics**: Prometheus metrics endpoint with provider/collector registry
|
||||
- **Pytest Helpers**: Async test client, database session management, `pytest-xdist` support, and table cleanup utilities
|
||||
|
||||
## License
|
||||
|
||||
MIT License - see [LICENSE](LICENSE) for details.
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome! Please feel free to submit issues and pull requests.
|
||||
137
docs/migration/v2.md
Normal file
137
docs/migration/v2.md
Normal file
@@ -0,0 +1,137 @@
|
||||
# Migrating to v2.0
|
||||
|
||||
This page covers every breaking change introduced in **v2.0** and the steps required to update your code.
|
||||
|
||||
---
|
||||
|
||||
## CRUD
|
||||
|
||||
### `schema` is now required in `offset_paginate()` and `cursor_paginate()`
|
||||
|
||||
Calls that omit `schema` will now raise a `TypeError` at runtime.
|
||||
|
||||
Previously `schema` was optional; omitting it returned raw SQLAlchemy model instances inside the response. It is now a required keyword argument and the response always contains serialized schema instances.
|
||||
|
||||
=== "Before (`v1`)"
|
||||
|
||||
```python
|
||||
# schema omitted — returned raw model instances
|
||||
result = await UserCrud.offset_paginate(session=session, page=1)
|
||||
result = await UserCrud.cursor_paginate(session=session, cursor=token)
|
||||
```
|
||||
|
||||
=== "Now (`v2`)"
|
||||
|
||||
```python
|
||||
result = await UserCrud.offset_paginate(session=session, page=1, schema=UserRead)
|
||||
result = await UserCrud.cursor_paginate(session=session, cursor=token, schema=UserRead)
|
||||
```
|
||||
|
||||
### `as_response` removed from `create()`, `get()`, and `update()`
|
||||
|
||||
Passing `as_response` to these methods will raise a `TypeError` at runtime.
|
||||
|
||||
The `as_response=True` shorthand is replaced by passing a `schema` directly. The return value is a `Response[schema]` when `schema` is provided, or the raw model instance when it is not.
|
||||
|
||||
=== "Before (`v1`)"
|
||||
|
||||
```python
|
||||
user = await UserCrud.create(session=session, obj=data, as_response=True)
|
||||
user = await UserCrud.get(session=session, filters=filters, as_response=True)
|
||||
user = await UserCrud.update(session=session, obj=data, filters, as_response=True)
|
||||
```
|
||||
|
||||
=== "Now (`v2`)"
|
||||
|
||||
```python
|
||||
user = await UserCrud.create(session=session, obj=data, schema=UserRead)
|
||||
user = await UserCrud.get(session=session, filters=filters, schema=UserRead)
|
||||
user = await UserCrud.update(session=session, obj=data, filters, schema=UserRead)
|
||||
```
|
||||
|
||||
### `delete()`: `as_response` renamed and return type changed
|
||||
|
||||
`as_response` is gone, and the plain (non-response) call no longer returns `True`.
|
||||
|
||||
Two changes were made to `delete()`:
|
||||
|
||||
1. The `as_response` parameter is renamed to `return_response`.
|
||||
2. When called without `return_response=True`, the method now returns `None` on success instead of `True`.
|
||||
|
||||
=== "Before (`v1`)"
|
||||
|
||||
```python
|
||||
ok = await UserCrud.delete(session=session, filters=filters)
|
||||
if ok: # True on success
|
||||
...
|
||||
|
||||
response = await UserCrud.delete(session=session, filters=filters, as_response=True)
|
||||
```
|
||||
|
||||
=== "Now (`v2`)"
|
||||
|
||||
```python
|
||||
await UserCrud.delete(session=session, filters=filters) # returns None
|
||||
|
||||
response = await UserCrud.delete(session=session, filters=filters, return_response=True)
|
||||
```
|
||||
|
||||
### `paginate()` alias removed
|
||||
|
||||
Any call to `crud.paginate(...)` will raise `AttributeError` at runtime.
|
||||
|
||||
The `paginate` shorthand was an alias for `offset_paginate`. It has been removed; call `offset_paginate` directly.
|
||||
|
||||
=== "Before (`v1`)"
|
||||
|
||||
```python
|
||||
result = await UserCrud.paginate(session=session, page=2, items_per_page=20, schema=UserRead)
|
||||
```
|
||||
|
||||
=== "Now (`v2`)"
|
||||
|
||||
```python
|
||||
result = await UserCrud.offset_paginate(session=session, page=2, items_per_page=20, schema=UserRead)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Exceptions
|
||||
|
||||
### Missing `api_error` raises `TypeError` at class definition time
|
||||
|
||||
Unfinished or stub exception subclasses that previously compiled fine will now fail on import.
|
||||
|
||||
In `v1`, a subclass without `api_error` would only fail when the exception was raised. In `v2`, `__init_subclass__` validates this at class definition time.
|
||||
|
||||
=== "Before (`v1`)"
|
||||
|
||||
```python
|
||||
class MyError(ApiException):
|
||||
pass # fine until raised
|
||||
```
|
||||
|
||||
=== "Now (`v2`)"
|
||||
|
||||
```python
|
||||
class MyError(ApiException):
|
||||
pass # TypeError: MyError must define an 'api_error' class attribute.
|
||||
```
|
||||
|
||||
For shared base classes that are not meant to be raised directly, use `abstract=True`:
|
||||
|
||||
```python
|
||||
class BillingError(ApiException, abstract=True):
|
||||
"""Base for all billing-related errors — not raised directly."""
|
||||
|
||||
class PaymentRequiredError(BillingError):
|
||||
api_error = ApiError(code=402, msg="Payment Required", desc="...", err_code="BILLING-402")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Schemas
|
||||
|
||||
### `Pagination` alias removed
|
||||
|
||||
`Pagination` was already deprecated in `v1` and is fully removed in `v2`, you now need to use [`OffsetPagination`](../reference/schemas.md#fastapi_toolsets.schemas.OffsetPagination) or [`CursorPagination`](../reference/schemas.md#fastapi_toolsets.schemas.CursorPagination).
|
||||
93
docs/module/cli.md
Normal file
93
docs/module/cli.md
Normal file
@@ -0,0 +1,93 @@
|
||||
# CLI
|
||||
|
||||
Typer-based command-line interface for managing your FastAPI application, with built-in fixture commands integration.
|
||||
|
||||
## Installation
|
||||
|
||||
=== "uv"
|
||||
``` bash
|
||||
uv add "fastapi-toolsets[cli]"
|
||||
```
|
||||
|
||||
=== "pip"
|
||||
``` bash
|
||||
pip install "fastapi-toolsets[cli]"
|
||||
```
|
||||
|
||||
## Overview
|
||||
|
||||
The `cli` module provides a `manager` entry point built with [Typer](https://typer.tiangolo.com/). It allow custom commands to be added in addition of the fixture commands when a [`FixtureRegistry`](../reference/fixtures.md#fastapi_toolsets.fixtures.registry.FixtureRegistry) and a database context are configured.
|
||||
|
||||
## Configuration
|
||||
|
||||
Configure the CLI in your `pyproject.toml`:
|
||||
|
||||
```toml
|
||||
[tool.fastapi-toolsets]
|
||||
cli = "myapp.cli:cli" # Custom Typer app
|
||||
fixtures = "myapp.fixtures:registry" # FixtureRegistry instance
|
||||
db_context = "myapp.db:db_context" # Async context manager for sessions
|
||||
```
|
||||
|
||||
All fields are optional. Without configuration, the `manager` command still works but no command are available.
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# Manager commands
|
||||
manager --help
|
||||
|
||||
Usage: manager [OPTIONS] COMMAND [ARGS]...
|
||||
|
||||
FastAPI utilities CLI.
|
||||
|
||||
╭─ Options ────────────────────────────────────────────────────────────────────────╮
|
||||
│ --install-completion Install completion for the current shell. │
|
||||
│ --show-completion Show completion for the current shell, to copy it │
|
||||
│ or customize the installation. │
|
||||
│ --help Show this message and exit. │
|
||||
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||
╭─ Commands ───────────────────────────────────────────────────────────────────────╮
|
||||
│ check-db │
|
||||
│ fixtures Manage database fixtures. │
|
||||
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||
|
||||
# Fixtures commands
|
||||
manager fixtures --help
|
||||
|
||||
Usage: manager fixtures [OPTIONS] COMMAND [ARGS]...
|
||||
|
||||
Manage database fixtures.
|
||||
|
||||
╭─ Options ────────────────────────────────────────────────────────────────────────╮
|
||||
│ --help Show this message and exit. │
|
||||
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||
╭─ Commands ───────────────────────────────────────────────────────────────────────╮
|
||||
│ list List all registered fixtures. │
|
||||
│ load Load fixtures into the database. │
|
||||
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||
```
|
||||
|
||||
## Custom CLI
|
||||
|
||||
You can extend the CLI by providing your own Typer app. The `manager` entry point will merge your app's commands with the built-in ones:
|
||||
|
||||
```python
|
||||
# myapp/cli.py
|
||||
import typer
|
||||
|
||||
cli = typer.Typer()
|
||||
|
||||
@cli.command()
|
||||
def hello():
|
||||
print("Hello from my app!")
|
||||
```
|
||||
|
||||
```toml
|
||||
[tool.fastapi-toolsets]
|
||||
cli = "myapp.cli:cli"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
[:material-api: API Reference](../reference/cli.md)
|
||||
671
docs/module/crud.md
Normal file
671
docs/module/crud.md
Normal file
@@ -0,0 +1,671 @@
|
||||
# CRUD
|
||||
|
||||
Generic async CRUD operations for SQLAlchemy models with search, pagination, and many-to-many support.
|
||||
|
||||
!!! info
|
||||
This module has been coded and tested to be compatible with PostgreSQL only.
|
||||
|
||||
## Overview
|
||||
|
||||
The `crud` module provides [`AsyncCrud`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud), a base class with a full suite of async database operations, and [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory), a convenience function to instantiate it for a given model.
|
||||
|
||||
## Creating a CRUD class
|
||||
|
||||
### Factory style
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.crud import CrudFactory
|
||||
from myapp.models import User
|
||||
|
||||
UserCrud = CrudFactory(model=User)
|
||||
```
|
||||
|
||||
[`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory) dynamically creates a class named `AsyncUserCrud` with `User` as its model. This is the most concise option for straightforward CRUD with no custom logic.
|
||||
|
||||
### Subclass style
|
||||
|
||||
!!! info "Added in `v2.3.0`"
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.crud.factory import AsyncCrud
|
||||
from myapp.models import User
|
||||
|
||||
class UserCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
searchable_fields = [User.username, User.email]
|
||||
default_load_options = [selectinload(User.role)]
|
||||
```
|
||||
|
||||
Subclassing [`AsyncCrud`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud) directly is the preferred style when you need to add custom methods or when the configuration is complex enough to benefit from a named class body.
|
||||
|
||||
### Adding custom methods
|
||||
|
||||
```python
|
||||
class UserCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
|
||||
@classmethod
|
||||
async def get_active(cls, session: AsyncSession) -> list[User]:
|
||||
return await cls.get_multi(session, filters=[User.is_active == True])
|
||||
```
|
||||
|
||||
### Sharing a custom base across multiple models
|
||||
|
||||
Define a generic base class with the shared methods, then subclass it for each model:
|
||||
|
||||
```python
|
||||
from typing import Generic, TypeVar
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from fastapi_toolsets.crud.factory import AsyncCrud
|
||||
|
||||
T = TypeVar("T", bound=DeclarativeBase)
|
||||
|
||||
class AuditedCrud(AsyncCrud[T], Generic[T]):
|
||||
"""Base CRUD with custom function"""
|
||||
|
||||
@classmethod
|
||||
async def get_active(cls, session: AsyncSession):
|
||||
return await cls.get_multi(session, filters=[cls.model.is_active == True])
|
||||
|
||||
|
||||
class UserCrud(AuditedCrud[User]):
|
||||
model = User
|
||||
searchable_fields = [User.username, User.email]
|
||||
```
|
||||
|
||||
You can also use the factory shorthand with the same base by passing `base_class`:
|
||||
|
||||
```python
|
||||
UserCrud = CrudFactory(User, base_class=AuditedCrud)
|
||||
```
|
||||
|
||||
## Basic operations
|
||||
|
||||
!!! info "`get_or_none` added in `v2.2`"
|
||||
|
||||
```python
|
||||
# Create
|
||||
user = await UserCrud.create(session=session, obj=UserCreateSchema(username="alice"))
|
||||
|
||||
# Get one (raises NotFoundError if not found)
|
||||
user = await UserCrud.get(session=session, filters=[User.id == user_id])
|
||||
|
||||
# Get one or None (never raises)
|
||||
user = await UserCrud.get_or_none(session=session, filters=[User.id == user_id])
|
||||
|
||||
# Get first or None
|
||||
user = await UserCrud.first(session=session, filters=[User.email == email])
|
||||
|
||||
# Get multiple
|
||||
users = await UserCrud.get_multi(session=session, filters=[User.is_active == True])
|
||||
|
||||
# Update
|
||||
user = await UserCrud.update(session=session, obj=UserUpdateSchema(username="bob"), filters=[User.id == user_id])
|
||||
|
||||
# Delete
|
||||
await UserCrud.delete(session=session, filters=[User.id == user_id])
|
||||
|
||||
# Count / exists
|
||||
count = await UserCrud.count(session=session, filters=[User.is_active == True])
|
||||
exists = await UserCrud.exists(session=session, filters=[User.email == email])
|
||||
```
|
||||
|
||||
## Fetching a single record
|
||||
|
||||
Three methods fetch a single record — choose based on how you want to handle the "not found" case and whether you need strict uniqueness:
|
||||
|
||||
| Method | Not found | Multiple results |
|
||||
|---|---|---|
|
||||
| `get` | raises `NotFoundError` | raises `MultipleResultsFound` |
|
||||
| `get_or_none` | returns `None` | raises `MultipleResultsFound` |
|
||||
| `first` | returns `None` | returns the first match silently |
|
||||
|
||||
Use `get` when the record must exist (e.g. a detail endpoint that should return 404):
|
||||
|
||||
```python
|
||||
user = await UserCrud.get(session=session, filters=[User.id == user_id])
|
||||
```
|
||||
|
||||
Use `get_or_none` when the record may not exist but you still want strict uniqueness enforcement:
|
||||
|
||||
```python
|
||||
user = await UserCrud.get_or_none(session=session, filters=[User.email == email])
|
||||
if user is None:
|
||||
... # handle missing case without catching an exception
|
||||
```
|
||||
|
||||
Use `first` when you only care about any one match and don't need uniqueness:
|
||||
|
||||
```python
|
||||
user = await UserCrud.first(session=session, filters=[User.is_active == True])
|
||||
```
|
||||
|
||||
## Pagination
|
||||
|
||||
!!! info "Added in `v1.1` (only offset_pagination via `paginate` if `<v1.1`)"
|
||||
|
||||
Three pagination methods are available. All return a typed response whose `pagination_type` field tells clients which strategy was used.
|
||||
|
||||
| | `offset_paginate` | `cursor_paginate` | `paginate` |
|
||||
|---|---|---|---|
|
||||
| Return type | `OffsetPaginatedResponse` | `CursorPaginatedResponse` | either, based on `pagination_type` param |
|
||||
| Total count | Yes | No | / |
|
||||
| Jump to arbitrary page | Yes | No | / |
|
||||
| Performance on deep pages | Degrades | Constant | / |
|
||||
| Stable under concurrent inserts | No | Yes | / |
|
||||
| Use case | Admin panels, numbered pagination | Feeds, APIs, infinite scroll | single endpoint, both strategies |
|
||||
|
||||
### Offset pagination
|
||||
|
||||
```python
|
||||
@router.get("")
|
||||
async def get_users(
|
||||
session: SessionDep,
|
||||
items_per_page: int = 50,
|
||||
page: int = 1,
|
||||
) -> OffsetPaginatedResponse[UserRead]:
|
||||
return await UserCrud.offset_paginate(
|
||||
session=session,
|
||||
items_per_page=items_per_page,
|
||||
page=page,
|
||||
schema=UserRead,
|
||||
)
|
||||
```
|
||||
|
||||
The [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) method returns an [`OffsetPaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.OffsetPaginatedResponse):
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"pagination_type": "offset",
|
||||
"data": ["..."],
|
||||
"pagination": {
|
||||
"total_count": 100,
|
||||
"pages": 5,
|
||||
"page": 1,
|
||||
"items_per_page": 20,
|
||||
"has_more": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Skipping the COUNT query
|
||||
|
||||
!!! info "Added in `v2.4.1`"
|
||||
|
||||
By default `offset_paginate` runs two queries: one for the page items and one `COUNT(*)` for `total_count`. On large tables the `COUNT` can be expensive. Pass `include_total=False` to skip it:
|
||||
|
||||
```python
|
||||
result = await UserCrud.offset_paginate(
|
||||
session=session,
|
||||
page=page,
|
||||
items_per_page=items_per_page,
|
||||
include_total=False,
|
||||
schema=UserRead,
|
||||
)
|
||||
```
|
||||
|
||||
#### Pagination params dependency
|
||||
|
||||
!!! info "Added in `v2.4.1`"
|
||||
|
||||
Use [`offset_params()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_params) to generate a FastAPI dependency that injects `page` and `items_per_page` from query parameters with configurable defaults and a `max_page_size` cap:
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
from fastapi import Depends
|
||||
|
||||
@router.get("")
|
||||
async def list_users(
|
||||
session: SessionDep,
|
||||
params: Annotated[dict, Depends(UserCrud.offset_params(default_page_size=20, max_page_size=100))],
|
||||
) -> OffsetPaginatedResponse[UserRead]:
|
||||
return await UserCrud.offset_paginate(session=session, **params, schema=UserRead)
|
||||
```
|
||||
|
||||
### Cursor pagination
|
||||
|
||||
```python
|
||||
@router.get("")
|
||||
async def list_users(
|
||||
session: SessionDep,
|
||||
cursor: str | None = None,
|
||||
items_per_page: int = 20,
|
||||
) -> CursorPaginatedResponse[UserRead]:
|
||||
return await UserCrud.cursor_paginate(
|
||||
session=session,
|
||||
cursor=cursor,
|
||||
items_per_page=items_per_page,
|
||||
schema=UserRead,
|
||||
)
|
||||
```
|
||||
|
||||
The [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) method returns a [`CursorPaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.CursorPaginatedResponse):
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"pagination_type": "cursor",
|
||||
"data": ["..."],
|
||||
"pagination": {
|
||||
"next_cursor": "eyJ2YWx1ZSI6ICIzZjQ3YWM2OS0uLi4ifQ==",
|
||||
"prev_cursor": null,
|
||||
"items_per_page": 20,
|
||||
"has_more": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Pass `next_cursor` as the `cursor` query parameter on the next request to advance to the next page. `prev_cursor` is set on pages 2+ and points back to the first item of the current page. Both are `null` when there is no adjacent page.
|
||||
|
||||
#### Choosing a cursor column
|
||||
|
||||
The cursor column is set once on [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory) via the `cursor_column` parameter. It must be monotonically ordered for stable results:
|
||||
|
||||
- Auto-increment integer PKs
|
||||
- UUID v7 PKs
|
||||
- Timestamps
|
||||
|
||||
!!! warning
|
||||
Random UUID v4 PKs are **not** suitable as cursor columns because their ordering is non-deterministic.
|
||||
|
||||
!!! note
|
||||
`cursor_column` is required. Calling [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) on a CRUD class that has no `cursor_column` configured raises a `ValueError`.
|
||||
|
||||
The cursor value is URL-safe base64-encoded (no padding) when returned to the client and decoded back to the correct Python type on the next request. The following SQLAlchemy column types are supported:
|
||||
|
||||
| SQLAlchemy type | Python type |
|
||||
|---|---|
|
||||
| `Integer`, `BigInteger`, `SmallInteger` | `int` |
|
||||
| `Uuid` | `uuid.UUID` |
|
||||
| `DateTime` | `datetime.datetime` |
|
||||
| `Date` | `datetime.date` |
|
||||
| `Float`, `Numeric` | `decimal.Decimal` |
|
||||
|
||||
```python
|
||||
# Paginate by the primary key
|
||||
PostCrud = CrudFactory(model=Post, cursor_column=Post.id)
|
||||
|
||||
# Paginate by a timestamp column instead
|
||||
PostCrud = CrudFactory(model=Post, cursor_column=Post.created_at)
|
||||
```
|
||||
|
||||
#### Pagination params dependency
|
||||
|
||||
!!! info "Added in `v2.4.1`"
|
||||
|
||||
Use [`cursor_params()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_params) to inject `cursor` and `items_per_page` from query parameters with a `max_page_size` cap:
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
from fastapi import Depends
|
||||
|
||||
@router.get("")
|
||||
async def list_users(
|
||||
session: SessionDep,
|
||||
params: Annotated[dict, Depends(UserCrud.cursor_params(default_page_size=20, max_page_size=100))],
|
||||
) -> CursorPaginatedResponse[UserRead]:
|
||||
return await UserCrud.cursor_paginate(session=session, **params, schema=UserRead)
|
||||
```
|
||||
|
||||
### Unified endpoint (both strategies)
|
||||
|
||||
!!! info "Added in `v2.3.0`"
|
||||
|
||||
[`paginate()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.paginate) dispatches to `offset_paginate` or `cursor_paginate` based on a `pagination_type` query parameter, letting you expose **one endpoint** that supports both strategies. The `pagination_type` field in the response tells clients which strategy was used, enabling frontend discriminated-union typing.
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.crud import PaginationType
|
||||
from fastapi_toolsets.schemas import PaginatedResponse
|
||||
|
||||
@router.get("")
|
||||
async def list_users(
|
||||
session: SessionDep,
|
||||
pagination_type: PaginationType = PaginationType.OFFSET,
|
||||
page: int = Query(1, ge=1, description="Current page (offset only)"),
|
||||
cursor: str | None = Query(None, description="Cursor token (cursor only)"),
|
||||
items_per_page: int = Query(20, ge=1, le=100),
|
||||
) -> PaginatedResponse[UserRead]:
|
||||
return await UserCrud.paginate(
|
||||
session,
|
||||
pagination_type=pagination_type,
|
||||
page=page,
|
||||
cursor=cursor,
|
||||
items_per_page=items_per_page,
|
||||
schema=UserRead,
|
||||
)
|
||||
```
|
||||
|
||||
```
|
||||
GET /users?pagination_type=offset&page=2&items_per_page=10
|
||||
GET /users?pagination_type=cursor&cursor=eyJ2YWx1ZSI6...&items_per_page=10
|
||||
```
|
||||
|
||||
#### Pagination params dependency
|
||||
|
||||
!!! info "Added in `v2.4.1`"
|
||||
|
||||
Use [`paginate_params()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.paginate_params) to inject all parameters at once with configurable defaults and a `max_page_size` cap:
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
from fastapi import Depends
|
||||
from fastapi_toolsets.schemas import PaginatedResponse
|
||||
|
||||
@router.get("")
|
||||
async def list_users(
|
||||
session: SessionDep,
|
||||
params: Annotated[dict, Depends(UserCrud.paginate_params(default_page_size=20, max_page_size=100))],
|
||||
) -> PaginatedResponse[UserRead]:
|
||||
return await UserCrud.paginate(session, **params, schema=UserRead)
|
||||
```
|
||||
|
||||
## Search
|
||||
|
||||
Two search strategies are available, both compatible with [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) and [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate).
|
||||
|
||||
| | Full-text search | Faceted search |
|
||||
|---|---|---|
|
||||
| Input | Free-text string | Exact column values |
|
||||
| Relationship support | Yes | Yes |
|
||||
| Use case | Search bars | Filter dropdowns |
|
||||
|
||||
!!! info "You can use both search strategies in the same endpoint!"
|
||||
|
||||
### Full-text search
|
||||
|
||||
!!! info "Added in `v2.2.1`"
|
||||
The model's primary key is always included in `searchable_fields` automatically, so searching by ID works out of the box without any configuration. When no `searchable_fields` are declared, only the primary key is searched.
|
||||
|
||||
Declare `searchable_fields` on the CRUD class. Relationship traversal is supported via tuples:
|
||||
|
||||
```python
|
||||
PostCrud = CrudFactory(
|
||||
model=Post,
|
||||
searchable_fields=[
|
||||
Post.title,
|
||||
Post.content,
|
||||
(Post.author, User.username), # search across relationship
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
You can override `searchable_fields` per call with `search_fields`:
|
||||
|
||||
```python
|
||||
result = await UserCrud.offset_paginate(
|
||||
session=session,
|
||||
search_fields=[User.country],
|
||||
)
|
||||
```
|
||||
|
||||
This allows searching with both [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) and [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate):
|
||||
|
||||
```python
|
||||
@router.get("")
|
||||
async def get_users(
|
||||
session: SessionDep,
|
||||
items_per_page: int = 50,
|
||||
page: int = 1,
|
||||
search: str | None = None,
|
||||
) -> OffsetPaginatedResponse[UserRead]:
|
||||
return await UserCrud.offset_paginate(
|
||||
session=session,
|
||||
items_per_page=items_per_page,
|
||||
page=page,
|
||||
search=search,
|
||||
schema=UserRead,
|
||||
)
|
||||
```
|
||||
|
||||
```python
|
||||
@router.get("")
|
||||
async def get_users(
|
||||
session: SessionDep,
|
||||
cursor: str | None = None,
|
||||
items_per_page: int = 50,
|
||||
search: str | None = None,
|
||||
) -> CursorPaginatedResponse[UserRead]:
|
||||
return await UserCrud.cursor_paginate(
|
||||
session=session,
|
||||
items_per_page=items_per_page,
|
||||
cursor=cursor,
|
||||
search=search,
|
||||
schema=UserRead,
|
||||
)
|
||||
```
|
||||
|
||||
### Faceted search
|
||||
|
||||
!!! info "Added in `v1.2`"
|
||||
|
||||
Declare `facet_fields` on the CRUD class to return distinct column values alongside paginated results. This is useful for populating filter dropdowns or building faceted search UIs.
|
||||
|
||||
Facet fields use the same syntax as `searchable_fields` — direct columns or relationship tuples:
|
||||
|
||||
```python
|
||||
UserCrud = CrudFactory(
|
||||
model=User,
|
||||
facet_fields=[
|
||||
User.status,
|
||||
User.country,
|
||||
(User.role, Role.name), # value from a related model
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
You can override `facet_fields` per call:
|
||||
|
||||
```python
|
||||
result = await UserCrud.offset_paginate(
|
||||
session=session,
|
||||
facet_fields=[User.country],
|
||||
)
|
||||
```
|
||||
|
||||
The distinct values are returned in the `filter_attributes` field of [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse):
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"data": ["..."],
|
||||
"pagination": { "..." },
|
||||
"filter_attributes": {
|
||||
"status": ["active", "inactive"],
|
||||
"country": ["DE", "FR", "US"],
|
||||
"name": ["admin", "editor", "viewer"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Use `filter_by` to pass the client's chosen filter values directly — no need to build SQLAlchemy conditions by hand. Any unknown key raises [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError).
|
||||
|
||||
!!! info "The keys in `filter_by` are the same keys the client received in `filter_attributes`."
|
||||
Keys are normally the terminal `column.key` (e.g. `"name"` for `Role.name`). When two facet fields share the same column key (e.g. `(Build.project, Project.name)` and `(Build.os, Os.name)`), the relationship name is prepended automatically: `"project__name"` and `"os__name"`.
|
||||
|
||||
`filter_by` and `filters` can be combined — both are applied with AND logic.
|
||||
|
||||
Use [`filter_params()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.filter_params) to generate a dict with the facet filter values from the query parameters:
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
|
||||
UserCrud = CrudFactory(
|
||||
model=User,
|
||||
facet_fields=[User.status, User.country, (User.role, Role.name)],
|
||||
)
|
||||
|
||||
@router.get("", response_model_exclude_none=True)
|
||||
async def list_users(
|
||||
session: SessionDep,
|
||||
page: int = 1,
|
||||
filter_by: Annotated[dict[str, list[str]], Depends(UserCrud.filter_params())],
|
||||
) -> OffsetPaginatedResponse[UserRead]:
|
||||
return await UserCrud.offset_paginate(
|
||||
session=session,
|
||||
page=page,
|
||||
filter_by=filter_by,
|
||||
schema=UserRead,
|
||||
)
|
||||
```
|
||||
|
||||
Both single-value and multi-value query parameters work:
|
||||
|
||||
```
|
||||
GET /users?status=active → filter_by={"status": ["active"]}
|
||||
GET /users?status=active&country=FR → filter_by={"status": ["active"], "country": ["FR"]}
|
||||
GET /users?role=admin&role=editor → filter_by={"role": ["admin", "editor"]} (IN clause)
|
||||
```
|
||||
|
||||
## Sorting
|
||||
|
||||
!!! info "Added in `v1.3`"
|
||||
|
||||
Declare `order_fields` on the CRUD class to expose client-driven column ordering via `order_by` and `order` query parameters.
|
||||
|
||||
```python
|
||||
UserCrud = CrudFactory(
|
||||
model=User,
|
||||
order_fields=[
|
||||
User.name,
|
||||
User.created_at,
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
Call [`order_params()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.order_params) to generate a FastAPI dependency that maps the query parameters to an [`OrderByClause`](../reference/crud.md#fastapi_toolsets.crud.factory.OrderByClause) expression:
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi_toolsets.crud import OrderByClause
|
||||
|
||||
@router.get("")
|
||||
async def list_users(
|
||||
session: SessionDep,
|
||||
order_by: Annotated[OrderByClause | None, Depends(UserCrud.order_params())],
|
||||
) -> OffsetPaginatedResponse[UserRead]:
|
||||
return await UserCrud.offset_paginate(session=session, order_by=order_by, schema=UserRead)
|
||||
```
|
||||
|
||||
The dependency adds two query parameters to the endpoint:
|
||||
|
||||
| Parameter | Type |
|
||||
| ---------- | --------------- |
|
||||
| `order_by` | `str | null` |
|
||||
| `order` | `asc` or `desc` |
|
||||
|
||||
```
|
||||
GET /users?order_by=name&order=asc → ORDER BY users.name ASC
|
||||
GET /users?order_by=name&order=desc → ORDER BY users.name DESC
|
||||
```
|
||||
|
||||
An unknown `order_by` value raises [`InvalidOrderFieldError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidOrderFieldError) (HTTP 422).
|
||||
|
||||
You can also pass `order_fields` directly to `order_params()` to override the class-level defaults without modifying them:
|
||||
|
||||
```python
|
||||
UserOrderParams = UserCrud.order_params(order_fields=[User.name])
|
||||
```
|
||||
|
||||
## Relationship loading
|
||||
|
||||
!!! info "Added in `v1.1`"
|
||||
|
||||
By default, SQLAlchemy relationships are not loaded unless explicitly requested. Instead of using `lazy="selectin"` on model definitions (which is implicit and applies globally), define a `default_load_options` on the CRUD class to control loading strategy explicitly.
|
||||
|
||||
!!! warning
|
||||
Avoid using `lazy="selectin"` on model relationships. It fires silently on every query, cannot be disabled per-call, and can cause unexpected cascading loads through deep relationship chains. Use `default_load_options` instead.
|
||||
|
||||
```python
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
ArticleCrud = CrudFactory(
|
||||
model=Article,
|
||||
default_load_options=[
|
||||
selectinload(Article.category),
|
||||
selectinload(Article.tags),
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
`default_load_options` applies automatically to all read operations (`get`, `first`, `get_multi`, `offset_paginate`, `cursor_paginate`). When `load_options` is passed at call-site, it **fully replaces** `default_load_options` for that query — giving you precise per-call control:
|
||||
|
||||
```python
|
||||
# Only loads category, tags are not loaded
|
||||
article = await ArticleCrud.get(
|
||||
session=session,
|
||||
filters=[Article.id == article_id],
|
||||
load_options=[selectinload(Article.category)],
|
||||
)
|
||||
|
||||
# Loads nothing — useful for write-then-refresh flows or lightweight checks
|
||||
articles = await ArticleCrud.get_multi(session=session, load_options=[])
|
||||
```
|
||||
|
||||
## Many-to-many relationships
|
||||
|
||||
Use `m2m_fields` to map schema fields containing lists of IDs to SQLAlchemy relationships. The CRUD class resolves and validates all IDs before persisting:
|
||||
|
||||
```python
|
||||
PostCrud = CrudFactory(
|
||||
model=Post,
|
||||
m2m_fields={"tag_ids": Post.tags},
|
||||
)
|
||||
|
||||
post = await PostCrud.create(session=session, obj=PostCreateSchema(title="Hello", tag_ids=[1, 2, 3]))
|
||||
```
|
||||
|
||||
## Upsert
|
||||
|
||||
Atomic `INSERT ... ON CONFLICT DO UPDATE` using PostgreSQL:
|
||||
|
||||
```python
|
||||
await UserCrud.upsert(
|
||||
session=session,
|
||||
obj=UserCreateSchema(email="alice@example.com", username="alice"),
|
||||
index_elements=[User.email],
|
||||
set_={"username"},
|
||||
)
|
||||
```
|
||||
|
||||
## Response serialization
|
||||
|
||||
!!! info "Added in `v1.1`"
|
||||
|
||||
Pass a Pydantic schema class to `create`, `get`, `update`, or `offset_paginate` to serialize the result directly into that schema and wrap it in a [`Response[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.Response) or [`PaginatedResponse[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse):
|
||||
|
||||
```python
|
||||
class UserRead(PydanticBase):
|
||||
id: UUID
|
||||
username: str
|
||||
|
||||
@router.get(
|
||||
"/{uuid}",
|
||||
responses=generate_error_responses(NotFoundError),
|
||||
)
|
||||
async def get_user(session: SessionDep, uuid: UUID) -> Response[UserRead]:
|
||||
return await crud.UserCrud.get(
|
||||
session=session,
|
||||
filters=[User.id == uuid],
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
@router.get("")
|
||||
async def list_users(session: SessionDep, page: int = 1) -> OffsetPaginatedResponse[UserRead]:
|
||||
return await crud.UserCrud.offset_paginate(
|
||||
session=session,
|
||||
page=page,
|
||||
schema=UserRead,
|
||||
)
|
||||
```
|
||||
|
||||
The schema must have `from_attributes=True` (or inherit from [`PydanticBase`](../reference/schemas.md#fastapi_toolsets.schemas.PydanticBase)) so it can be built from SQLAlchemy model instances.
|
||||
|
||||
---
|
||||
|
||||
[:material-api: API Reference](../reference/crud.md)
|
||||
123
docs/module/db.md
Normal file
123
docs/module/db.md
Normal file
@@ -0,0 +1,123 @@
|
||||
# DB
|
||||
|
||||
SQLAlchemy async session management with transactions, table locking, and row-change polling.
|
||||
|
||||
!!! info
|
||||
This module has been coded and tested to be compatible with PostgreSQL only.
|
||||
|
||||
## Overview
|
||||
|
||||
The `db` module provides helpers to create FastAPI dependencies and context managers for `AsyncSession`, along with utilities for nested transactions, table lock and polling for row changes.
|
||||
|
||||
## Session dependency
|
||||
|
||||
Use [`create_db_dependency`](../reference/db.md#fastapi_toolsets.db.create_db_dependency) to create a FastAPI dependency that yields a session and auto-commits on success:
|
||||
|
||||
```python
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||
from fastapi_toolsets.db import create_db_dependency
|
||||
|
||||
engine = create_async_engine(url="postgresql+asyncpg://...", future=True)
|
||||
session_maker = async_sessionmaker(bind=engine, expire_on_commit=False)
|
||||
|
||||
get_db = create_db_dependency(session_maker=session_maker)
|
||||
|
||||
@router.get("/users")
|
||||
async def list_users(session: AsyncSession = Depends(get_db)):
|
||||
...
|
||||
```
|
||||
|
||||
## Session context manager
|
||||
|
||||
Use [`create_db_context`](../reference/db.md#fastapi_toolsets.db.create_db_context) for sessions outside request handlers (e.g. background tasks, CLI commands):
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.db import create_db_context
|
||||
|
||||
db_context = create_db_context(session_maker=session_maker)
|
||||
|
||||
async def seed():
|
||||
async with db_context() as session:
|
||||
...
|
||||
```
|
||||
|
||||
## Nested transactions
|
||||
|
||||
[`get_transaction`](../reference/db.md#fastapi_toolsets.db.get_transaction) handles savepoints automatically, allowing safe nesting:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.db import get_transaction
|
||||
|
||||
async def create_user_with_role(session=session):
|
||||
async with get_transaction(session=session):
|
||||
...
|
||||
async with get_transaction(session=session): # uses savepoint
|
||||
...
|
||||
```
|
||||
|
||||
## Table locking
|
||||
|
||||
[`lock_tables`](../reference/db.md#fastapi_toolsets.db.lock_tables) acquires PostgreSQL table-level locks before executing critical sections:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.db import lock_tables
|
||||
|
||||
async with lock_tables(session=session, tables=[User], mode="EXCLUSIVE"):
|
||||
# No other transaction can modify User until this block exits
|
||||
...
|
||||
```
|
||||
|
||||
Available lock modes are defined in [`LockMode`](../reference/db.md#fastapi_toolsets.db.LockMode): `ACCESS_SHARE`, `ROW_SHARE`, `ROW_EXCLUSIVE`, `SHARE_UPDATE_EXCLUSIVE`, `SHARE`, `SHARE_ROW_EXCLUSIVE`, `EXCLUSIVE`, `ACCESS_EXCLUSIVE`.
|
||||
|
||||
## Row-change polling
|
||||
|
||||
[`wait_for_row_change`](../reference/db.md#fastapi_toolsets.db.wait_for_row_change) polls a row until a specific column changes value, useful for waiting on async side effects:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.db import wait_for_row_change
|
||||
|
||||
# Wait up to 30s for order.status to change
|
||||
await wait_for_row_change(
|
||||
session=session,
|
||||
model=Order,
|
||||
pk_value=order_id,
|
||||
columns=[Order.status],
|
||||
interval=1.0,
|
||||
timeout=30.0,
|
||||
)
|
||||
```
|
||||
|
||||
## Creating a database
|
||||
|
||||
!!! info "Added in `v2.1`"
|
||||
|
||||
[`create_database`](../reference/db.md#fastapi_toolsets.db.create_database) creates a database at a given URL. It connects to *server_url* and issues a `CREATE DATABASE` statement:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.db import create_database
|
||||
|
||||
SERVER_URL = "postgresql+asyncpg://postgres:postgres@localhost/postgres"
|
||||
|
||||
await create_database(db_name="myapp_test", server_url=SERVER_URL)
|
||||
```
|
||||
|
||||
For test isolation with automatic cleanup, use [`create_worker_database`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_worker_database) from the `pytest` module instead — it handles drop-before, create, and drop-after automatically.
|
||||
|
||||
## Cleaning up tables
|
||||
|
||||
!!! info "Added in `v2.1`"
|
||||
|
||||
[`cleanup_tables`](../reference/db.md#fastapi_toolsets.db.cleanup_tables) truncates all tables:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.db import cleanup_tables
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def clean(db_session):
|
||||
yield
|
||||
await cleanup_tables(session=db_session, base=Base)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
[:material-api: API Reference](../reference/db.md)
|
||||
61
docs/module/dependencies.md
Normal file
61
docs/module/dependencies.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# Dependencies
|
||||
|
||||
FastAPI dependency factories for automatic model resolution from path and body parameters.
|
||||
|
||||
## Overview
|
||||
|
||||
The `dependencies` module provides two factory functions that create FastAPI dependencies to fetch a model instance from the database automatically — either from a path parameter or from a request body field — and inject it directly into your route handler.
|
||||
|
||||
## `PathDependency`
|
||||
|
||||
[`PathDependency`](../reference/dependencies.md#fastapi_toolsets.dependencies.PathDependency) resolves a model from a URL path parameter and injects it into the route handler. Raises [`NotFoundError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NotFoundError) automatically if the record does not exist.
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.dependencies import PathDependency
|
||||
|
||||
# Plain callable
|
||||
UserDep = PathDependency(model=User, field=User.id, session_dep=get_db)
|
||||
|
||||
# Annotated
|
||||
SessionDep = Annotated[AsyncSession, Depends(get_db)]
|
||||
UserDep = PathDependency(model=User, field=User.id, session_dep=SessionDep)
|
||||
|
||||
@router.get("/users/{user_id}")
|
||||
async def get_user(user: User = UserDep):
|
||||
return user
|
||||
```
|
||||
|
||||
By default the parameter name is inferred from the field (`user_id` for `User.id`). You can override it:
|
||||
|
||||
```python
|
||||
UserDep = PathDependency(model=User, field=User.id, session_dep=get_db, param_name="id")
|
||||
|
||||
@router.get("/users/{id}")
|
||||
async def get_user(user: User = UserDep):
|
||||
return user
|
||||
```
|
||||
|
||||
## `BodyDependency`
|
||||
|
||||
[`BodyDependency`](../reference/dependencies.md#fastapi_toolsets.dependencies.BodyDependency) resolves a model from a field in the request body. Useful when a body contains a foreign key and you want the full object injected:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.dependencies import BodyDependency
|
||||
|
||||
# Plain callable
|
||||
RoleDep = BodyDependency(model=Role, field=Role.id, session_dep=get_db, body_field="role_id")
|
||||
|
||||
# Annotated
|
||||
SessionDep = Annotated[AsyncSession, Depends(get_db)]
|
||||
RoleDep = BodyDependency(model=Role, field=Role.id, session_dep=SessionDep, body_field="role_id")
|
||||
|
||||
|
||||
@router.post("/users")
|
||||
async def create_user(body: UserCreateSchema, role: Role = RoleDep):
|
||||
user = User(username=body.username, role=role)
|
||||
...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
[:material-api: API Reference](../reference/dependencies.md)
|
||||
131
docs/module/exceptions.md
Normal file
131
docs/module/exceptions.md
Normal file
@@ -0,0 +1,131 @@
|
||||
# Exceptions
|
||||
|
||||
Structured API exceptions with consistent error responses and automatic OpenAPI documentation.
|
||||
|
||||
## Overview
|
||||
|
||||
The `exceptions` module provides a set of pre-built HTTP exceptions and a FastAPI exception handler that formats all errors — including validation errors — into a uniform [`ErrorResponse`](../reference/schemas.md#fastapi_toolsets.schemas.ErrorResponse).
|
||||
|
||||
## Setup
|
||||
|
||||
Register the exception handlers on your FastAPI app at startup:
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app=app)
|
||||
```
|
||||
|
||||
This registers handlers for:
|
||||
|
||||
- [`ApiException`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ApiException) — all custom exceptions below
|
||||
- `HTTPException` — Starlette/FastAPI HTTP errors
|
||||
- `RequestValidationError` — Pydantic request validation (422)
|
||||
- `ResponseValidationError` — Pydantic response validation (422)
|
||||
- `Exception` — unhandled errors (500)
|
||||
|
||||
It also patches `app.openapi()` to replace the default Pydantic 422 schema with a structured example matching the `ErrorResponse` format.
|
||||
|
||||
## Built-in exceptions
|
||||
|
||||
| Exception | Status | Default message |
|
||||
|-----------|--------|-----------------|
|
||||
| [`UnauthorizedError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.UnauthorizedError) | 401 | Unauthorized |
|
||||
| [`ForbiddenError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ForbiddenError) | 403 | Forbidden |
|
||||
| [`NotFoundError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NotFoundError) | 404 | Not Found |
|
||||
| [`ConflictError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ConflictError) | 409 | Conflict |
|
||||
| [`NoSearchableFieldsError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError) | 400 | No Searchable Fields |
|
||||
| [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError) | 400 | Invalid Facet Filter |
|
||||
| [`InvalidOrderFieldError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidOrderFieldError) | 422 | Invalid Order Field |
|
||||
|
||||
### Per-instance overrides
|
||||
|
||||
All built-in exceptions accept optional keyword arguments to customise the response for a specific raise site without changing the class defaults:
|
||||
|
||||
| Argument | Effect |
|
||||
|----------|--------|
|
||||
| `detail` | Overrides both `str(exc)` (log output) and the `message` field in the response body |
|
||||
| `desc` | Overrides the `description` field |
|
||||
| `data` | Overrides the `data` field |
|
||||
|
||||
```python
|
||||
raise NotFoundError(detail="User 42 not found", desc="No user with that ID exists in the database.")
|
||||
```
|
||||
|
||||
## Custom exceptions
|
||||
|
||||
Subclass [`ApiException`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ApiException) and define an `api_error` class variable:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.exceptions import ApiException
|
||||
from fastapi_toolsets.schemas import ApiError
|
||||
|
||||
class PaymentRequiredError(ApiException):
|
||||
api_error = ApiError(
|
||||
code=402,
|
||||
msg="Payment Required",
|
||||
desc="Your subscription has expired.",
|
||||
err_code="BILLING-402",
|
||||
)
|
||||
```
|
||||
|
||||
!!! warning
|
||||
Subclasses that do not define `api_error` raise a `TypeError` at **class creation time**, not at raise time.
|
||||
|
||||
### Custom `__init__`
|
||||
|
||||
Override `__init__` to compute `detail`, `desc`, or `data` dynamically, then delegate to `super().__init__()`:
|
||||
|
||||
```python
|
||||
class OrderValidationError(ApiException):
|
||||
api_error = ApiError(
|
||||
code=422,
|
||||
msg="Order Validation Failed",
|
||||
desc="One or more order fields are invalid.",
|
||||
err_code="ORDER-422",
|
||||
)
|
||||
|
||||
def __init__(self, *field_errors: str) -> None:
|
||||
super().__init__(
|
||||
f"{len(field_errors)} validation error(s)",
|
||||
desc=", ".join(field_errors),
|
||||
data={"errors": [{"message": e} for e in field_errors]},
|
||||
)
|
||||
```
|
||||
|
||||
### Intermediate base classes
|
||||
|
||||
Use `abstract=True` when creating a shared base that is not meant to be raised directly:
|
||||
|
||||
```python
|
||||
class BillingError(ApiException, abstract=True):
|
||||
"""Base for all billing-related errors."""
|
||||
|
||||
class PaymentRequiredError(BillingError):
|
||||
api_error = ApiError(code=402, msg="Payment Required", desc="...", err_code="BILLING-402")
|
||||
|
||||
class SubscriptionExpiredError(BillingError):
|
||||
api_error = ApiError(code=402, msg="Subscription Expired", desc="...", err_code="BILLING-402-EXP")
|
||||
```
|
||||
|
||||
## OpenAPI response documentation
|
||||
|
||||
Use [`generate_error_responses`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.generate_error_responses) to add error schemas to your endpoint's OpenAPI spec:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.exceptions import generate_error_responses, NotFoundError, ForbiddenError
|
||||
|
||||
@router.get(
|
||||
"/users/{id}",
|
||||
responses=generate_error_responses(NotFoundError, ForbiddenError),
|
||||
)
|
||||
async def get_user(...): ...
|
||||
```
|
||||
|
||||
Multiple exceptions sharing the same HTTP status code are grouped under one entry, each appearing as a named example keyed by its `err_code`. This keeps the OpenAPI UI readable when several error variants map to the same status.
|
||||
|
||||
---
|
||||
|
||||
[:material-api: API Reference](../reference/exceptions.md)
|
||||
123
docs/module/fixtures.md
Normal file
123
docs/module/fixtures.md
Normal file
@@ -0,0 +1,123 @@
|
||||
# Fixtures
|
||||
|
||||
Dependency-aware database seeding with context-based loading strategies.
|
||||
|
||||
## Overview
|
||||
|
||||
The `fixtures` module lets you define named fixtures with dependencies between them, then load them into the database in the correct order. Fixtures can be scoped to contexts (e.g. base data, testing data) so that only the relevant ones are loaded for each environment.
|
||||
|
||||
## Defining fixtures
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.fixtures import FixtureRegistry, Context
|
||||
|
||||
fixtures = FixtureRegistry()
|
||||
|
||||
@fixtures.register
|
||||
def roles():
|
||||
return [
|
||||
Role(id=1, name="admin"),
|
||||
Role(id=2, name="user"),
|
||||
]
|
||||
|
||||
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
|
||||
def test_users():
|
||||
return [
|
||||
User(id=1, username="alice", role_id=1),
|
||||
User(id=2, username="bob", role_id=2),
|
||||
]
|
||||
```
|
||||
|
||||
Dependencies declared via `depends_on` are resolved topologically — `roles` will always be loaded before `test_users`.
|
||||
|
||||
## Loading fixtures
|
||||
|
||||
By context with [`load_fixtures_by_context`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.load_fixtures_by_context):
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.fixtures import load_fixtures_by_context
|
||||
|
||||
async with db_context() as session:
|
||||
await load_fixtures_by_context(session=session, registry=fixtures, context=Context.TESTING)
|
||||
```
|
||||
|
||||
Directly with [`load_fixtures`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.load_fixtures):
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.fixtures import load_fixtures
|
||||
|
||||
async with db_context() as session:
|
||||
await load_fixtures(session=session, registry=fixtures)
|
||||
```
|
||||
|
||||
## Contexts
|
||||
|
||||
[`Context`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.Context) is an enum with predefined values:
|
||||
|
||||
| Context | Description |
|
||||
|---------|-------------|
|
||||
| `Context.BASE` | Core data required in all environments |
|
||||
| `Context.TESTING` | Data only loaded during tests |
|
||||
| `Context.PRODUCTION` | Data only loaded in production |
|
||||
|
||||
A fixture with no `contexts` defined takes `Context.BASE` by default.
|
||||
|
||||
## Load strategies
|
||||
|
||||
[`LoadStrategy`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.LoadStrategy) controls how the fixture loader handles rows that already exist:
|
||||
|
||||
| Strategy | Description |
|
||||
|----------|-------------|
|
||||
| `LoadStrategy.INSERT` | Insert only, fail on duplicates |
|
||||
| `LoadStrategy.UPSERT` | Insert or update on conflict |
|
||||
| `LoadStrategy.SKIP` | Skip rows that already exist |
|
||||
|
||||
## Merging registries
|
||||
|
||||
Split fixtures definitions across modules and merge them:
|
||||
|
||||
```python
|
||||
from myapp.fixtures.dev import dev_fixtures
|
||||
from myapp.fixtures.prod import prod_fixtures
|
||||
|
||||
fixtures = fixturesRegistry()
|
||||
fixtures.include_registry(registry=dev_fixtures)
|
||||
fixtures.include_registry(registry=prod_fixtures)
|
||||
|
||||
## Pytest integration
|
||||
|
||||
Use [`register_fixtures`](../reference/pytest.md#fastapi_toolsets.pytest.plugin.register_fixtures) to expose each fixture in your registry as an injectable pytest fixture named `fixture_{name}` by default:
|
||||
|
||||
```python
|
||||
# conftest.py
|
||||
import pytest
|
||||
from fastapi_toolsets.pytest import create_db_session, register_fixtures
|
||||
from app.fixtures import registry
|
||||
from app.models import Base
|
||||
|
||||
DATABASE_URL = "postgresql+asyncpg://user:pass@localhost/test_db"
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session():
|
||||
async with create_db_session(database_url=DATABASE_URL, base=Base, cleanup=True) as session:
|
||||
yield session
|
||||
|
||||
register_fixtures(registry=registry, namespace=globals())
|
||||
```
|
||||
|
||||
```python
|
||||
# test_users.py
|
||||
async def test_user_can_login(fixture_users: list[User], fixture_roles: list[Role]):
|
||||
...
|
||||
```
|
||||
|
||||
|
||||
The load order is resolved automatically from the `depends_on` declarations in your registry. Each generated fixture receives `db_session` as a dependency and returns the list of loaded model instances.
|
||||
|
||||
## CLI integration
|
||||
|
||||
Fixtures can be triggered from the CLI. See the [CLI module](cli.md) for setup instructions.
|
||||
|
||||
---
|
||||
|
||||
[:material-api: API Reference](../reference/fixtures.md)
|
||||
39
docs/module/logger.md
Normal file
39
docs/module/logger.md
Normal file
@@ -0,0 +1,39 @@
|
||||
# Logger
|
||||
|
||||
Lightweight logging utilities with consistent formatting and uvicorn integration.
|
||||
|
||||
## Overview
|
||||
|
||||
The `logger` module provides two helpers: one to configure the root logger (and uvicorn loggers) at startup, and one to retrieve a named logger anywhere in your codebase.
|
||||
|
||||
## Setup
|
||||
|
||||
Call [`configure_logging`](../reference/logger.md#fastapi_toolsets.logger.configure_logging) once at application startup:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.logger import configure_logging
|
||||
|
||||
configure_logging(level="INFO")
|
||||
```
|
||||
|
||||
This sets up a stdout handler with a consistent format and also configures uvicorn's access and error loggers so all log output shares the same style.
|
||||
|
||||
## Getting a logger
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.logger import get_logger
|
||||
|
||||
logger = get_logger(name=__name__)
|
||||
logger.info("User created")
|
||||
```
|
||||
|
||||
When called without arguments, [`get_logger`](../reference/logger.md#fastapi_toolsets.logger.get_logger) auto-detects the caller's module name via frame inspection:
|
||||
|
||||
```python
|
||||
# Equivalent to get_logger(name=__name__)
|
||||
logger = get_logger()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
[:material-api: API Reference](../reference/logger.md)
|
||||
109
docs/module/metrics.md
Normal file
109
docs/module/metrics.md
Normal file
@@ -0,0 +1,109 @@
|
||||
# Metrics
|
||||
|
||||
Prometheus metrics integration with a decorator-based registry and multi-process support.
|
||||
|
||||
## Installation
|
||||
|
||||
=== "uv"
|
||||
``` bash
|
||||
uv add "fastapi-toolsets[metrics]"
|
||||
```
|
||||
|
||||
=== "pip"
|
||||
``` bash
|
||||
pip install "fastapi-toolsets[metrics]"
|
||||
```
|
||||
|
||||
## Overview
|
||||
|
||||
The `metrics` module provides a [`MetricsRegistry`](../reference/metrics.md#fastapi_toolsets.metrics.registry.MetricsRegistry) to declare Prometheus metrics with decorators, and an [`init_metrics`](../reference/metrics.md#fastapi_toolsets.metrics.handler.init_metrics) function to mount a `/metrics` endpoint on your FastAPI app.
|
||||
|
||||
## Setup
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
from fastapi_toolsets.metrics import MetricsRegistry, init_metrics
|
||||
|
||||
app = FastAPI()
|
||||
metrics = MetricsRegistry()
|
||||
|
||||
init_metrics(app=app, registry=metrics)
|
||||
```
|
||||
|
||||
This mounts the `/metrics` endpoint that Prometheus can scrape.
|
||||
|
||||
## Declaring metrics
|
||||
|
||||
### Providers
|
||||
|
||||
Providers are called once at startup by `init_metrics`. The return value (the Prometheus metric object) is stored in the registry and can be retrieved later with [`registry.get(name)`](../reference/metrics.md#fastapi_toolsets.metrics.registry.MetricsRegistry.get).
|
||||
|
||||
Use providers when you want **deferred initialization**: the Prometheus metric is not registered with the global `CollectorRegistry` until `init_metrics` runs, not at import time. This is particularly useful for testing — importing the module in a test suite without calling `init_metrics` leaves no metrics registered, avoiding cross-test pollution.
|
||||
|
||||
It is also useful when metrics are defined across multiple modules and merged with `include_registry`: any code that needs a metric can call `metrics.get()` on the shared registry instead of importing the metric directly from its origin module.
|
||||
|
||||
If neither of these applies to you, declaring metrics at module level (e.g. `HTTP_REQUESTS = Counter(...)`) is simpler and equally valid.
|
||||
|
||||
```python
|
||||
from prometheus_client import Counter, Histogram
|
||||
|
||||
@metrics.register
|
||||
def http_requests():
|
||||
return Counter("http_requests_total", "Total HTTP requests", ["method", "status"])
|
||||
|
||||
@metrics.register
|
||||
def request_duration():
|
||||
return Histogram("request_duration_seconds", "Request duration")
|
||||
```
|
||||
|
||||
To use a provider's metric elsewhere (e.g. in a middleware), call `metrics.get()` inside the handler — **not** at module level, as providers are only initialized when `init_metrics` runs:
|
||||
|
||||
```python
|
||||
async def metrics_middleware(request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
metrics.get("http_requests").labels(
|
||||
method=request.method, status=response.status_code
|
||||
).inc()
|
||||
return response
|
||||
```
|
||||
|
||||
### Collectors
|
||||
|
||||
Collectors are called on every scrape. Use them for metrics that reflect current state (e.g. gauges).
|
||||
|
||||
!!! warning "Declare the metric at module level"
|
||||
Do **not** instantiate the Prometheus metric inside the collector function. Doing so recreates it on every scrape, raising `ValueError: Duplicated timeseries in CollectorRegistry`. Declare it once at module level instead:
|
||||
|
||||
```python
|
||||
from prometheus_client import Gauge
|
||||
|
||||
_queue_depth = Gauge("queue_depth", "Current queue depth")
|
||||
|
||||
@metrics.register(collect=True)
|
||||
def collect_queue_depth():
|
||||
_queue_depth.set(get_current_queue_depth())
|
||||
```
|
||||
|
||||
## Merging registries
|
||||
|
||||
Split metrics definitions across modules and merge them:
|
||||
|
||||
```python
|
||||
from myapp.metrics.http import http_metrics
|
||||
from myapp.metrics.db import db_metrics
|
||||
|
||||
metrics = MetricsRegistry()
|
||||
metrics.include_registry(registry=http_metrics)
|
||||
metrics.include_registry(registry=db_metrics)
|
||||
```
|
||||
|
||||
## Multi-process mode
|
||||
|
||||
Multi-process support is enabled automatically when the `PROMETHEUS_MULTIPROC_DIR` environment variable is set. No code changes are required.
|
||||
|
||||
!!! warning "Environment variable name"
|
||||
The correct variable is `PROMETHEUS_MULTIPROC_DIR` (not `PROMETHEUS_MULTIPROCESS_DIR`).
|
||||
|
||||
---
|
||||
|
||||
[:material-api: API Reference](../reference/metrics.md)
|
||||
255
docs/module/models.md
Normal file
255
docs/module/models.md
Normal file
@@ -0,0 +1,255 @@
|
||||
# Models
|
||||
|
||||
!!! info "Added in `v2.0`"
|
||||
|
||||
Reusable SQLAlchemy 2.0 mixins for common column patterns, designed to be composed freely on any `DeclarativeBase` model.
|
||||
|
||||
## Overview
|
||||
|
||||
The `models` module provides mixins that each add a single, well-defined column behaviour. They work with standard SQLAlchemy 2.0 declarative syntax and are fully compatible with `AsyncSession`.
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.models import UUIDMixin, TimestampMixin
|
||||
|
||||
class Article(Base, UUIDMixin, TimestampMixin):
|
||||
__tablename__ = "articles"
|
||||
|
||||
title: Mapped[str]
|
||||
content: Mapped[str]
|
||||
```
|
||||
|
||||
All timestamp columns are timezone-aware (`TIMESTAMPTZ`). All defaults are server-side (`clock_timestamp()`), so they are also applied when inserting rows via raw SQL outside the ORM.
|
||||
|
||||
## Mixins
|
||||
|
||||
### [`UUIDMixin`](../reference/models.md#fastapi_toolsets.models.UUIDMixin)
|
||||
|
||||
Adds a `id: UUID` primary key generated server-side by PostgreSQL using `gen_random_uuid()`. The value is retrieved via `RETURNING` after insert, so it is available on the Python object immediately after `flush()`.
|
||||
|
||||
!!! warning "Requires PostgreSQL 13+"
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.models import UUIDMixin
|
||||
|
||||
class User(Base, UUIDMixin):
|
||||
__tablename__ = "users"
|
||||
|
||||
username: Mapped[str]
|
||||
|
||||
# id is None before flush
|
||||
user = User(username="alice")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
print(user.id) # UUID('...')
|
||||
```
|
||||
|
||||
### [`UUIDv7Mixin`](../reference/models.md#fastapi_toolsets.models.UUIDv7Mixin)
|
||||
|
||||
!!! info "Added in `v2.3`"
|
||||
|
||||
Adds a `id: UUID` primary key generated server-side by PostgreSQL using `uuidv7()`. It's a time-ordered UUID format that encodes a millisecond-precision timestamp in the most significant bits, making it naturally sortable and index-friendly.
|
||||
|
||||
!!! warning "Requires PostgreSQL 18+"
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.models import UUIDv7Mixin
|
||||
|
||||
class Event(Base, UUIDv7Mixin):
|
||||
__tablename__ = "events"
|
||||
|
||||
name: Mapped[str]
|
||||
|
||||
# id is None before flush
|
||||
event = Event(name="user.signup")
|
||||
session.add(event)
|
||||
await session.flush()
|
||||
print(event.id) # UUID('019...')
|
||||
```
|
||||
|
||||
### [`CreatedAtMixin`](../reference/models.md#fastapi_toolsets.models.CreatedAtMixin)
|
||||
|
||||
Adds a `created_at: datetime` column set to `clock_timestamp()` on insert. The column has no `onupdate` hook — it is intentionally immutable after the row is created.
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.models import UUIDMixin, CreatedAtMixin
|
||||
|
||||
class Order(Base, UUIDMixin, CreatedAtMixin):
|
||||
__tablename__ = "orders"
|
||||
|
||||
total: Mapped[float]
|
||||
```
|
||||
|
||||
### [`UpdatedAtMixin`](../reference/models.md#fastapi_toolsets.models.UpdatedAtMixin)
|
||||
|
||||
Adds an `updated_at: datetime` column set to `clock_timestamp()` on insert and automatically updated to `clock_timestamp()` on every ORM-level update (via SQLAlchemy's `onupdate` hook).
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.models import UUIDMixin, UpdatedAtMixin
|
||||
|
||||
class Post(Base, UUIDMixin, UpdatedAtMixin):
|
||||
__tablename__ = "posts"
|
||||
|
||||
title: Mapped[str]
|
||||
|
||||
post = Post(title="Hello")
|
||||
await session.flush()
|
||||
await session.refresh(post)
|
||||
|
||||
post.title = "Hello World"
|
||||
await session.flush()
|
||||
await session.refresh(post)
|
||||
print(post.updated_at)
|
||||
```
|
||||
|
||||
!!! note
|
||||
`updated_at` is updated by SQLAlchemy at ORM flush time. If you update rows via raw SQL (e.g. `UPDATE posts SET ...`), the column will **not** be updated automatically — use a database trigger if you need that guarantee.
|
||||
|
||||
### [`TimestampMixin`](../reference/models.md#fastapi_toolsets.models.TimestampMixin)
|
||||
|
||||
Convenience mixin that combines [`CreatedAtMixin`](../reference/models.md#fastapi_toolsets.models.CreatedAtMixin) and [`UpdatedAtMixin`](../reference/models.md#fastapi_toolsets.models.UpdatedAtMixin). Equivalent to inheriting both.
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.models import UUIDMixin, TimestampMixin
|
||||
|
||||
class Article(Base, UUIDMixin, TimestampMixin):
|
||||
__tablename__ = "articles"
|
||||
|
||||
title: Mapped[str]
|
||||
```
|
||||
|
||||
### [`WatchedFieldsMixin`](../reference/models.md#fastapi_toolsets.models.WatchedFieldsMixin)
|
||||
|
||||
!!! info "Added in `v2.4`"
|
||||
|
||||
`WatchedFieldsMixin` provides lifecycle callbacks that fire **after commit** — meaning the row is durably persisted when your callback runs. If the transaction rolls back, no callback fires.
|
||||
|
||||
Three callbacks are available, each corresponding to a [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) value:
|
||||
|
||||
| Callback | Event | Trigger |
|
||||
|---|---|---|
|
||||
| `on_create()` | `ModelEvent.CREATE` | After `INSERT` |
|
||||
| `on_delete()` | `ModelEvent.DELETE` | After `DELETE` |
|
||||
| `on_update(changes)` | `ModelEvent.UPDATE` | After `UPDATE` on a watched field |
|
||||
|
||||
Server-side defaults (e.g. `id`, `created_at`) are fully populated in all callbacks. All callbacks support both `async def` and plain `def`. Use `@watch` to restrict which fields trigger `on_update`:
|
||||
|
||||
| Decorator | `on_update` behaviour |
|
||||
|---|---|
|
||||
| `@watch("status", "role")` | Only fires when `status` or `role` changes |
|
||||
| *(no decorator)* | Fires when **any** mapped field changes |
|
||||
|
||||
`@watch` is inherited through the class hierarchy. If a subclass does not declare its own `@watch`, it uses the filter from the nearest decorated parent. Applying `@watch` on the subclass overrides the parent's filter:
|
||||
|
||||
```python
|
||||
@watch("status")
|
||||
class Order(Base, UUIDMixin, WatchedFieldsMixin):
|
||||
...
|
||||
|
||||
class UrgentOrder(Order):
|
||||
# inherits @watch("status") — on_update fires only for status changes
|
||||
...
|
||||
|
||||
@watch("priority")
|
||||
class PriorityOrder(Order):
|
||||
# overrides parent — on_update fires only for priority changes
|
||||
...
|
||||
```
|
||||
|
||||
#### Option 1 — catch-all with `on_event`
|
||||
|
||||
Override `on_event` to handle all event types in one place. The specific methods delegate here by default:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.models import ModelEvent, UUIDMixin, WatchedFieldsMixin, watch
|
||||
|
||||
@watch("status")
|
||||
class Order(Base, UUIDMixin, WatchedFieldsMixin):
|
||||
__tablename__ = "orders"
|
||||
|
||||
status: Mapped[str]
|
||||
|
||||
async def on_event(self, event: ModelEvent, changes: dict | None = None) -> None:
|
||||
if event == ModelEvent.CREATE:
|
||||
await notify_new_order(self.id)
|
||||
elif event == ModelEvent.DELETE:
|
||||
await notify_order_cancelled(self.id)
|
||||
elif event == ModelEvent.UPDATE:
|
||||
await notify_status_change(self.id, changes["status"])
|
||||
```
|
||||
|
||||
#### Option 2 — targeted overrides
|
||||
|
||||
Override individual methods for more focused logic:
|
||||
|
||||
```python
|
||||
@watch("status")
|
||||
class Order(Base, UUIDMixin, WatchedFieldsMixin):
|
||||
__tablename__ = "orders"
|
||||
|
||||
status: Mapped[str]
|
||||
|
||||
async def on_create(self) -> None:
|
||||
await notify_new_order(self.id)
|
||||
|
||||
async def on_delete(self) -> None:
|
||||
await notify_order_cancelled(self.id)
|
||||
|
||||
async def on_update(self, changes: dict) -> None:
|
||||
if "status" in changes:
|
||||
old = changes["status"]["old"]
|
||||
new = changes["status"]["new"]
|
||||
await notify_status_change(self.id, old, new)
|
||||
```
|
||||
|
||||
#### Field changes format
|
||||
|
||||
The `changes` dict maps each watched field that changed to `{"old": ..., "new": ...}`. Only fields that actually changed are included:
|
||||
|
||||
```python
|
||||
# status changed → {"status": {"old": "pending", "new": "shipped"}}
|
||||
# two fields changed → {"status": {...}, "assigned_to": {...}}
|
||||
```
|
||||
|
||||
!!! info "Multiple flushes in one transaction are merged: the earliest `old` and latest `new` are preserved, and `on_update` fires only once per commit."
|
||||
|
||||
!!! warning "Callbacks fire only for ORM-level changes. Rows updated via raw SQL (`UPDATE ... SET ...`) are not detected."
|
||||
|
||||
!!! warning "Callbacks fire after the **outermost** transaction commits."
|
||||
If you create several related objects using `CrudFactory.create` and need
|
||||
callbacks to see all of them (including associations), wrap the whole
|
||||
operation in a single [`get_transaction`](db.md) block. Without it, each
|
||||
`create` call commits independently and `on_create` fires before the
|
||||
remaining objects exist.
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.db import get_transaction
|
||||
|
||||
async with get_transaction(session):
|
||||
order = await OrderCrud.create(session, order_data)
|
||||
item = await ItemCrud.create(session, item_data)
|
||||
await session.refresh(order, attribute_names=["items"])
|
||||
order.items.append(item)
|
||||
# on_create fires here for both order and item,
|
||||
# with the full association already committed.
|
||||
```
|
||||
|
||||
## Composing mixins
|
||||
|
||||
All mixins can be combined in any order. The only constraint is that exactly one primary key must be defined — either via `UUIDMixin` or directly on the model.
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.models import UUIDMixin, TimestampMixin
|
||||
|
||||
class Event(Base, UUIDMixin, TimestampMixin):
|
||||
__tablename__ = "events"
|
||||
name: Mapped[str]
|
||||
|
||||
class Counter(Base, UpdatedAtMixin):
|
||||
__tablename__ = "counters"
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
value: Mapped[int]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
[:material-api: API Reference](../reference/models.md)
|
||||
98
docs/module/pytest.md
Normal file
98
docs/module/pytest.md
Normal file
@@ -0,0 +1,98 @@
|
||||
# Pytest
|
||||
|
||||
Testing helpers for FastAPI applications with async client, database sessions, and parallel worker support.
|
||||
|
||||
## Installation
|
||||
|
||||
=== "uv"
|
||||
``` bash
|
||||
uv add "fastapi-toolsets[pytest]"
|
||||
```
|
||||
|
||||
=== "pip"
|
||||
``` bash
|
||||
pip install "fastapi-toolsets[pytest]"
|
||||
```
|
||||
|
||||
## Overview
|
||||
|
||||
The `pytest` module provides utilities for setting up async test clients, managing test database sessions, and supporting parallel test execution with `pytest-xdist`.
|
||||
|
||||
## Creating an async client
|
||||
|
||||
Use [`create_async_client`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_async_client) to get an `httpx.AsyncClient` configured for your FastAPI app:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.pytest import create_async_client
|
||||
|
||||
@pytest.fixture
|
||||
async def http_client(db_session):
|
||||
async def _override_get_db():
|
||||
yield db_session
|
||||
|
||||
async with create_async_client(
|
||||
app=app,
|
||||
base_url="http://127.0.0.1/api/v1",
|
||||
dependency_overrides={get_db: _override_get_db},
|
||||
) as c:
|
||||
yield c
|
||||
```
|
||||
|
||||
## Database sessions in tests
|
||||
|
||||
Use [`create_db_session`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_db_session) to create an isolated `AsyncSession` for a test, combined with [`create_worker_database`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_worker_database) to set up a per-worker database:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.pytest import create_worker_database, create_db_session
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def worker_db_url():
|
||||
async with create_worker_database(
|
||||
database_url=str(settings.SQLALCHEMY_DATABASE_URI)
|
||||
) as url:
|
||||
yield url
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(worker_db_url):
|
||||
async with create_db_session(
|
||||
database_url=worker_db_url, base=Base, cleanup=True
|
||||
) as session:
|
||||
yield session
|
||||
```
|
||||
|
||||
!!! info
|
||||
In this example, the database is reset between each test using the argument `cleanup=True`.
|
||||
|
||||
Use [`worker_database_url`](../reference/pytest.md#fastapi_toolsets.pytest.utils.worker_database_url) to derive the per-worker URL manually if needed:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.pytest import worker_database_url
|
||||
|
||||
url = worker_database_url("postgresql+asyncpg://user:pass@localhost/test_db", default_test_db="test")
|
||||
# e.g. "postgresql+asyncpg://user:pass@localhost/test_db_gw0" under xdist
|
||||
```
|
||||
|
||||
## Parallel testing with pytest-xdist
|
||||
|
||||
The examples above are already compatible with parallel test execution with `pytest-xdist`.
|
||||
|
||||
## Cleaning up tables
|
||||
|
||||
!!! warning
|
||||
Since `V2.1.0` `cleanup_tables` now live in `fastapi_toolsets.db`. For backward compatibility the function is still available in `fastapi_toolsets.pytest`, but this will be remove in `V3.0.0`.
|
||||
|
||||
If you want to manually clean up a database you can use [`cleanup_tables`](../reference/db.md#fastapi_toolsets.db.cleanup_tables), this will truncate all tables between tests for fast isolation:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.db import cleanup_tables
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def clean(db_session):
|
||||
yield
|
||||
await cleanup_tables(session=db_session, base=Base)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
[:material-api: API Reference](../reference/pytest.md)
|
||||
140
docs/module/schemas.md
Normal file
140
docs/module/schemas.md
Normal file
@@ -0,0 +1,140 @@
|
||||
# Schemas
|
||||
|
||||
Standardized Pydantic response models for consistent API responses across your FastAPI application.
|
||||
|
||||
## Overview
|
||||
|
||||
The `schemas` module provides generic response wrappers that enforce a uniform response structure. All models use `from_attributes=True` for ORM compatibility and `validate_assignment=True` for runtime type safety.
|
||||
|
||||
## Response models
|
||||
|
||||
### [`Response[T]`](../reference/schemas.md#fastapi_toolsets.schemas.Response)
|
||||
|
||||
The most common wrapper for a single resource response.
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
@router.get("/users/{id}")
|
||||
async def get_user(user: User = UserDep) -> Response[UserSchema]:
|
||||
return Response(data=user, message="User retrieved")
|
||||
```
|
||||
|
||||
### Paginated response models
|
||||
|
||||
Three classes wrap paginated list results. Pick the one that matches your endpoint's strategy:
|
||||
|
||||
| Class | `pagination` type | `pagination_type` field | Use when |
|
||||
|---|---|---|---|
|
||||
| [`OffsetPaginatedResponse[T]`](#offsetpaginatedresponset) | `OffsetPagination` | `"offset"` (fixed) | endpoint always uses offset |
|
||||
| [`CursorPaginatedResponse[T]`](#cursorpaginatedresponset) | `CursorPagination` | `"cursor"` (fixed) | endpoint always uses cursor |
|
||||
| [`PaginatedResponse[T]`](#paginatedresponset) | `OffsetPagination \| CursorPagination` | — | unified endpoint supporting both strategies |
|
||||
|
||||
#### [`OffsetPaginatedResponse[T]`](../reference/schemas.md#fastapi_toolsets.schemas.OffsetPaginatedResponse)
|
||||
|
||||
!!! info "Added in `v2.3.0`"
|
||||
|
||||
Use as the return type when the endpoint always uses [`offset_paginate`](crud.md#offset-pagination). The `pagination` field is guaranteed to be an [`OffsetPagination`](../reference/schemas.md#fastapi_toolsets.schemas.OffsetPagination) object; the response always includes a `pagination_type: "offset"` discriminator.
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.schemas import OffsetPaginatedResponse
|
||||
|
||||
@router.get("/users")
|
||||
async def list_users(
|
||||
page: int = 1,
|
||||
items_per_page: int = 20,
|
||||
) -> OffsetPaginatedResponse[UserSchema]:
|
||||
return await UserCrud.offset_paginate(
|
||||
session, page=page, items_per_page=items_per_page, schema=UserSchema
|
||||
)
|
||||
```
|
||||
|
||||
**Response shape:**
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"pagination_type": "offset",
|
||||
"data": ["..."],
|
||||
"pagination": {
|
||||
"total_count": 100,
|
||||
"page": 1,
|
||||
"items_per_page": 20,
|
||||
"has_more": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### [`CursorPaginatedResponse[T]`](../reference/schemas.md#fastapi_toolsets.schemas.CursorPaginatedResponse)
|
||||
|
||||
!!! info "Added in `v2.3.0`"
|
||||
|
||||
Use as the return type when the endpoint always uses [`cursor_paginate`](crud.md#cursor-pagination). The `pagination` field is guaranteed to be a [`CursorPagination`](../reference/schemas.md#fastapi_toolsets.schemas.CursorPagination) object; the response always includes a `pagination_type: "cursor"` discriminator.
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.schemas import CursorPaginatedResponse
|
||||
|
||||
@router.get("/events")
|
||||
async def list_events(
|
||||
cursor: str | None = None,
|
||||
items_per_page: int = 20,
|
||||
) -> CursorPaginatedResponse[EventSchema]:
|
||||
return await EventCrud.cursor_paginate(
|
||||
session, cursor=cursor, items_per_page=items_per_page, schema=EventSchema
|
||||
)
|
||||
```
|
||||
|
||||
**Response shape:**
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"pagination_type": "cursor",
|
||||
"data": ["..."],
|
||||
"pagination": {
|
||||
"next_cursor": "eyJpZCI6IDQyfQ==",
|
||||
"prev_cursor": null,
|
||||
"items_per_page": 20,
|
||||
"has_more": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### [`PaginatedResponse[T]`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse)
|
||||
|
||||
Return type for endpoints that support **both** pagination strategies via a `pagination_type` query parameter (using [`paginate()`](crud.md#unified-paginate--both-strategies-on-one-endpoint)).
|
||||
|
||||
When used as a return annotation, `PaginatedResponse[T]` automatically expands to `Annotated[Union[CursorPaginatedResponse[T], OffsetPaginatedResponse[T]], Field(discriminator="pagination_type")]`, so FastAPI emits a proper `oneOf` + discriminator in the OpenAPI schema with no extra boilerplate:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.crud import PaginationType
|
||||
from fastapi_toolsets.schemas import PaginatedResponse
|
||||
|
||||
@router.get("/users")
|
||||
async def list_users(
|
||||
pagination_type: PaginationType = PaginationType.OFFSET,
|
||||
page: int = 1,
|
||||
cursor: str | None = None,
|
||||
items_per_page: int = 20,
|
||||
) -> PaginatedResponse[UserSchema]:
|
||||
return await UserCrud.paginate(
|
||||
session,
|
||||
pagination_type=pagination_type,
|
||||
page=page,
|
||||
cursor=cursor,
|
||||
items_per_page=items_per_page,
|
||||
schema=UserSchema,
|
||||
)
|
||||
```
|
||||
|
||||
#### Pagination metadata models
|
||||
|
||||
The optional `filter_attributes` field is populated when `facet_fields` are configured on the CRUD class (see [Filter attributes](crud.md#filter-attributes-facets)). It is `None` by default and can be hidden from API responses with `response_model_exclude_none=True`.
|
||||
|
||||
### [`ErrorResponse`](../reference/schemas.md#fastapi_toolsets.schemas.ErrorResponse)
|
||||
|
||||
Returned automatically by the exceptions handler.
|
||||
|
||||
---
|
||||
|
||||
[:material-api: API Reference](../reference/schemas.md)
|
||||
7
docs/overrides/main.html
Normal file
7
docs/overrides/main.html
Normal file
@@ -0,0 +1,7 @@
|
||||
{% extends "base.html" %} {% block extrahead %}
|
||||
<script
|
||||
defer
|
||||
src="https://analytics.d3vyce.fr/script.js"
|
||||
data-website-id="338b8816-7b99-4c6a-82f3-15595be3fd47"
|
||||
></script>
|
||||
{{ super() }} {% endblock %}
|
||||
27
docs/reference/cli.md
Normal file
27
docs/reference/cli.md
Normal file
@@ -0,0 +1,27 @@
|
||||
# `cli`
|
||||
|
||||
Here's the reference for the CLI configuration helpers used to load settings from `pyproject.toml`.
|
||||
|
||||
You can import them directly from `fastapi_toolsets.cli.config`:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.cli.config import (
|
||||
import_from_string,
|
||||
get_config_value,
|
||||
get_fixtures_registry,
|
||||
get_db_context,
|
||||
get_custom_cli,
|
||||
)
|
||||
```
|
||||
|
||||
## ::: fastapi_toolsets.cli.config.import_from_string
|
||||
|
||||
## ::: fastapi_toolsets.cli.config.get_config_value
|
||||
|
||||
## ::: fastapi_toolsets.cli.config.get_fixtures_registry
|
||||
|
||||
## ::: fastapi_toolsets.cli.config.get_db_context
|
||||
|
||||
## ::: fastapi_toolsets.cli.config.get_custom_cli
|
||||
|
||||
## ::: fastapi_toolsets.cli.utils.async_command
|
||||
20
docs/reference/crud.md
Normal file
20
docs/reference/crud.md
Normal file
@@ -0,0 +1,20 @@
|
||||
# `crud`
|
||||
|
||||
Here's the reference for the CRUD classes, factory, and search utilities.
|
||||
|
||||
You can import the main symbols from `fastapi_toolsets.crud`:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.crud import CrudFactory, AsyncCrud
|
||||
from fastapi_toolsets.crud.search import SearchConfig, get_searchable_fields, build_search_filters
|
||||
```
|
||||
|
||||
## ::: fastapi_toolsets.crud.factory.AsyncCrud
|
||||
|
||||
## ::: fastapi_toolsets.crud.factory.CrudFactory
|
||||
|
||||
## ::: fastapi_toolsets.crud.search.SearchConfig
|
||||
|
||||
## ::: fastapi_toolsets.crud.search.get_searchable_fields
|
||||
|
||||
## ::: fastapi_toolsets.crud.search.build_search_filters
|
||||
34
docs/reference/db.md
Normal file
34
docs/reference/db.md
Normal file
@@ -0,0 +1,34 @@
|
||||
# `db`
|
||||
|
||||
Here's the reference for all database session utilities, transaction helpers, and locking functions.
|
||||
|
||||
You can import them directly from `fastapi_toolsets.db`:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.db import (
|
||||
LockMode,
|
||||
cleanup_tables,
|
||||
create_database,
|
||||
create_db_dependency,
|
||||
create_db_context,
|
||||
get_transaction,
|
||||
lock_tables,
|
||||
wait_for_row_change,
|
||||
)
|
||||
```
|
||||
|
||||
## ::: fastapi_toolsets.db.LockMode
|
||||
|
||||
## ::: fastapi_toolsets.db.create_db_dependency
|
||||
|
||||
## ::: fastapi_toolsets.db.create_db_context
|
||||
|
||||
## ::: fastapi_toolsets.db.get_transaction
|
||||
|
||||
## ::: fastapi_toolsets.db.lock_tables
|
||||
|
||||
## ::: fastapi_toolsets.db.wait_for_row_change
|
||||
|
||||
## ::: fastapi_toolsets.db.create_database
|
||||
|
||||
## ::: fastapi_toolsets.db.cleanup_tables
|
||||
13
docs/reference/dependencies.md
Normal file
13
docs/reference/dependencies.md
Normal file
@@ -0,0 +1,13 @@
|
||||
# `dependencies`
|
||||
|
||||
Here's the reference for the FastAPI dependency factory functions.
|
||||
|
||||
You can import them directly from `fastapi_toolsets.dependencies`:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.dependencies import PathDependency, BodyDependency
|
||||
```
|
||||
|
||||
## ::: fastapi_toolsets.dependencies.PathDependency
|
||||
|
||||
## ::: fastapi_toolsets.dependencies.BodyDependency
|
||||
40
docs/reference/exceptions.md
Normal file
40
docs/reference/exceptions.md
Normal file
@@ -0,0 +1,40 @@
|
||||
# `exceptions`
|
||||
|
||||
Here's the reference for all exception classes and handler utilities.
|
||||
|
||||
You can import them directly from `fastapi_toolsets.exceptions`:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.exceptions import (
|
||||
ApiException,
|
||||
UnauthorizedError,
|
||||
ForbiddenError,
|
||||
NotFoundError,
|
||||
ConflictError,
|
||||
NoSearchableFieldsError,
|
||||
InvalidFacetFilterError,
|
||||
InvalidOrderFieldError,
|
||||
generate_error_responses,
|
||||
init_exceptions_handlers,
|
||||
)
|
||||
```
|
||||
|
||||
## ::: fastapi_toolsets.exceptions.exceptions.ApiException
|
||||
|
||||
## ::: fastapi_toolsets.exceptions.exceptions.UnauthorizedError
|
||||
|
||||
## ::: fastapi_toolsets.exceptions.exceptions.ForbiddenError
|
||||
|
||||
## ::: fastapi_toolsets.exceptions.exceptions.NotFoundError
|
||||
|
||||
## ::: fastapi_toolsets.exceptions.exceptions.ConflictError
|
||||
|
||||
## ::: fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError
|
||||
|
||||
## ::: fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError
|
||||
|
||||
## ::: fastapi_toolsets.exceptions.exceptions.InvalidOrderFieldError
|
||||
|
||||
## ::: fastapi_toolsets.exceptions.exceptions.generate_error_responses
|
||||
|
||||
## ::: fastapi_toolsets.exceptions.handler.init_exceptions_handlers
|
||||
31
docs/reference/fixtures.md
Normal file
31
docs/reference/fixtures.md
Normal file
@@ -0,0 +1,31 @@
|
||||
# `fixtures`
|
||||
|
||||
Here's the reference for the fixture registry, enums, and loading utilities.
|
||||
|
||||
You can import them directly from `fastapi_toolsets.fixtures`:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.fixtures import (
|
||||
Context,
|
||||
LoadStrategy,
|
||||
Fixture,
|
||||
FixtureRegistry,
|
||||
load_fixtures,
|
||||
load_fixtures_by_context,
|
||||
get_obj_by_attr,
|
||||
)
|
||||
```
|
||||
|
||||
## ::: fastapi_toolsets.fixtures.enum.Context
|
||||
|
||||
## ::: fastapi_toolsets.fixtures.enum.LoadStrategy
|
||||
|
||||
## ::: fastapi_toolsets.fixtures.registry.Fixture
|
||||
|
||||
## ::: fastapi_toolsets.fixtures.registry.FixtureRegistry
|
||||
|
||||
## ::: fastapi_toolsets.fixtures.utils.load_fixtures
|
||||
|
||||
## ::: fastapi_toolsets.fixtures.utils.load_fixtures_by_context
|
||||
|
||||
## ::: fastapi_toolsets.fixtures.utils.get_obj_by_attr
|
||||
13
docs/reference/logger.md
Normal file
13
docs/reference/logger.md
Normal file
@@ -0,0 +1,13 @@
|
||||
# `logger`
|
||||
|
||||
Here's the reference for the logging utilities.
|
||||
|
||||
You can import them directly from `fastapi_toolsets.logger`:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.logger import configure_logging, get_logger
|
||||
```
|
||||
|
||||
## ::: fastapi_toolsets.logger.configure_logging
|
||||
|
||||
## ::: fastapi_toolsets.logger.get_logger
|
||||
15
docs/reference/metrics.md
Normal file
15
docs/reference/metrics.md
Normal file
@@ -0,0 +1,15 @@
|
||||
# `metrics`
|
||||
|
||||
Here's the reference for the Prometheus metrics registry and endpoint handler.
|
||||
|
||||
You can import them directly from `fastapi_toolsets.metrics`:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.metrics import Metric, MetricsRegistry, init_metrics
|
||||
```
|
||||
|
||||
## ::: fastapi_toolsets.metrics.registry.Metric
|
||||
|
||||
## ::: fastapi_toolsets.metrics.registry.MetricsRegistry
|
||||
|
||||
## ::: fastapi_toolsets.metrics.handler.init_metrics
|
||||
34
docs/reference/models.md
Normal file
34
docs/reference/models.md
Normal file
@@ -0,0 +1,34 @@
|
||||
# `models`
|
||||
|
||||
Here's the reference for the SQLAlchemy model mixins provided by the `models` module.
|
||||
|
||||
You can import them directly from `fastapi_toolsets.models`:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.models import (
|
||||
ModelEvent,
|
||||
UUIDMixin,
|
||||
UUIDv7Mixin,
|
||||
CreatedAtMixin,
|
||||
UpdatedAtMixin,
|
||||
TimestampMixin,
|
||||
WatchedFieldsMixin,
|
||||
watch,
|
||||
)
|
||||
```
|
||||
|
||||
## ::: fastapi_toolsets.models.ModelEvent
|
||||
|
||||
## ::: fastapi_toolsets.models.UUIDMixin
|
||||
|
||||
## ::: fastapi_toolsets.models.UUIDv7Mixin
|
||||
|
||||
## ::: fastapi_toolsets.models.CreatedAtMixin
|
||||
|
||||
## ::: fastapi_toolsets.models.UpdatedAtMixin
|
||||
|
||||
## ::: fastapi_toolsets.models.TimestampMixin
|
||||
|
||||
## ::: fastapi_toolsets.models.WatchedFieldsMixin
|
||||
|
||||
## ::: fastapi_toolsets.models.watch
|
||||
26
docs/reference/pytest.md
Normal file
26
docs/reference/pytest.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# `pytest`
|
||||
|
||||
Here's the reference for all testing utilities and pytest fixtures.
|
||||
|
||||
You can import them directly from `fastapi_toolsets.pytest`:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.pytest import (
|
||||
register_fixtures,
|
||||
create_async_client,
|
||||
create_db_session,
|
||||
worker_database_url,
|
||||
create_worker_database,
|
||||
cleanup_tables,
|
||||
)
|
||||
```
|
||||
|
||||
## ::: fastapi_toolsets.pytest.plugin.register_fixtures
|
||||
|
||||
## ::: fastapi_toolsets.pytest.utils.create_async_client
|
||||
|
||||
## ::: fastapi_toolsets.pytest.utils.create_db_session
|
||||
|
||||
## ::: fastapi_toolsets.pytest.utils.worker_database_url
|
||||
|
||||
## ::: fastapi_toolsets.pytest.utils.create_worker_database
|
||||
46
docs/reference/schemas.md
Normal file
46
docs/reference/schemas.md
Normal file
@@ -0,0 +1,46 @@
|
||||
# `schemas`
|
||||
|
||||
Here's the reference for all response models and types provided by the `schemas` module.
|
||||
|
||||
You can import them directly from `fastapi_toolsets.schemas`:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.schemas import (
|
||||
PydanticBase,
|
||||
ResponseStatus,
|
||||
ApiError,
|
||||
BaseResponse,
|
||||
Response,
|
||||
ErrorResponse,
|
||||
OffsetPagination,
|
||||
CursorPagination,
|
||||
PaginationType,
|
||||
PaginatedResponse,
|
||||
OffsetPaginatedResponse,
|
||||
CursorPaginatedResponse,
|
||||
)
|
||||
```
|
||||
|
||||
## ::: fastapi_toolsets.schemas.PydanticBase
|
||||
|
||||
## ::: fastapi_toolsets.schemas.ResponseStatus
|
||||
|
||||
## ::: fastapi_toolsets.schemas.ApiError
|
||||
|
||||
## ::: fastapi_toolsets.schemas.BaseResponse
|
||||
|
||||
## ::: fastapi_toolsets.schemas.Response
|
||||
|
||||
## ::: fastapi_toolsets.schemas.ErrorResponse
|
||||
|
||||
## ::: fastapi_toolsets.schemas.OffsetPagination
|
||||
|
||||
## ::: fastapi_toolsets.schemas.CursorPagination
|
||||
|
||||
## ::: fastapi_toolsets.schemas.PaginationType
|
||||
|
||||
## ::: fastapi_toolsets.schemas.PaginatedResponse
|
||||
|
||||
## ::: fastapi_toolsets.schemas.OffsetPaginatedResponse
|
||||
|
||||
## ::: fastapi_toolsets.schemas.CursorPaginatedResponse
|
||||
0
docs_src/__init__.py
Normal file
0
docs_src/__init__.py
Normal file
0
docs_src/examples/__init__.py
Normal file
0
docs_src/examples/__init__.py
Normal file
0
docs_src/examples/pagination_search/__init__.py
Normal file
0
docs_src/examples/pagination_search/__init__.py
Normal file
9
docs_src/examples/pagination_search/app.py
Normal file
9
docs_src/examples/pagination_search/app.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||
|
||||
from .routes import router
|
||||
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app=app)
|
||||
app.include_router(router=router)
|
||||
21
docs_src/examples/pagination_search/crud.py
Normal file
21
docs_src/examples/pagination_search/crud.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from fastapi_toolsets.crud import CrudFactory
|
||||
|
||||
from .models import Article, Category
|
||||
|
||||
ArticleCrud = CrudFactory(
|
||||
model=Article,
|
||||
cursor_column=Article.created_at,
|
||||
searchable_fields=[ # default fields for full-text search
|
||||
Article.title,
|
||||
Article.body,
|
||||
(Article.category, Category.name),
|
||||
],
|
||||
facet_fields=[ # fields exposed as filter dropdowns
|
||||
Article.status,
|
||||
(Article.category, Category.name),
|
||||
],
|
||||
order_fields=[ # fields exposed for client-driven ordering
|
||||
Article.title,
|
||||
Article.created_at,
|
||||
],
|
||||
)
|
||||
17
docs_src/examples/pagination_search/db.py
Normal file
17
docs_src/examples/pagination_search/db.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from fastapi_toolsets.db import create_db_context, create_db_dependency
|
||||
|
||||
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/postgres"
|
||||
|
||||
engine = create_async_engine(url=DATABASE_URL, future=True)
|
||||
async_session_maker = async_sessionmaker(bind=engine, expire_on_commit=False)
|
||||
|
||||
get_db = create_db_dependency(session_maker=async_session_maker)
|
||||
get_db_context = create_db_context(session_maker=async_session_maker)
|
||||
|
||||
|
||||
SessionDep = Annotated[AsyncSession, Depends(get_db)]
|
||||
34
docs_src/examples/pagination_search/models.py
Normal file
34
docs_src/examples/pagination_search/models.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import Boolean, ForeignKey, String, Text
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
|
||||
from fastapi_toolsets.models import CreatedAtMixin
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class Category(Base):
|
||||
__tablename__ = "categories"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4)
|
||||
name: Mapped[str] = mapped_column(String(64), unique=True)
|
||||
|
||||
articles: Mapped[list["Article"]] = relationship(back_populates="category")
|
||||
|
||||
|
||||
class Article(Base, CreatedAtMixin):
|
||||
__tablename__ = "articles"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4)
|
||||
title: Mapped[str] = mapped_column(String(256))
|
||||
body: Mapped[str] = mapped_column(Text)
|
||||
status: Mapped[str] = mapped_column(String(32))
|
||||
published: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
category_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
ForeignKey("categories.id"), nullable=True
|
||||
)
|
||||
|
||||
category: Mapped["Category | None"] = relationship(back_populates="articles")
|
||||
89
docs_src/examples/pagination_search/routes.py
Normal file
89
docs_src/examples/pagination_search/routes.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from fastapi_toolsets.crud import OrderByClause
|
||||
from fastapi_toolsets.schemas import (
|
||||
CursorPaginatedResponse,
|
||||
OffsetPaginatedResponse,
|
||||
PaginatedResponse,
|
||||
)
|
||||
|
||||
from .crud import ArticleCrud
|
||||
from .db import SessionDep
|
||||
from .models import Article
|
||||
from .schemas import ArticleRead
|
||||
|
||||
router = APIRouter(prefix="/articles")
|
||||
|
||||
|
||||
@router.get("/offset")
|
||||
async def list_articles_offset(
|
||||
session: SessionDep,
|
||||
params: Annotated[
|
||||
dict,
|
||||
Depends(ArticleCrud.offset_params(default_page_size=20, max_page_size=100)),
|
||||
],
|
||||
filter_by: Annotated[dict[str, list[str]], Depends(ArticleCrud.filter_params())],
|
||||
order_by: Annotated[
|
||||
OrderByClause | None,
|
||||
Depends(ArticleCrud.order_params(default_field=Article.created_at)),
|
||||
],
|
||||
search: str | None = None,
|
||||
) -> OffsetPaginatedResponse[ArticleRead]:
|
||||
return await ArticleCrud.offset_paginate(
|
||||
session=session,
|
||||
**params,
|
||||
search=search,
|
||||
filter_by=filter_by or None,
|
||||
order_by=order_by,
|
||||
schema=ArticleRead,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/cursor")
|
||||
async def list_articles_cursor(
|
||||
session: SessionDep,
|
||||
params: Annotated[
|
||||
dict,
|
||||
Depends(ArticleCrud.cursor_params(default_page_size=20, max_page_size=100)),
|
||||
],
|
||||
filter_by: Annotated[dict[str, list[str]], Depends(ArticleCrud.filter_params())],
|
||||
order_by: Annotated[
|
||||
OrderByClause | None,
|
||||
Depends(ArticleCrud.order_params(default_field=Article.created_at)),
|
||||
],
|
||||
search: str | None = None,
|
||||
) -> CursorPaginatedResponse[ArticleRead]:
|
||||
return await ArticleCrud.cursor_paginate(
|
||||
session=session,
|
||||
**params,
|
||||
search=search,
|
||||
filter_by=filter_by or None,
|
||||
order_by=order_by,
|
||||
schema=ArticleRead,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_articles(
|
||||
session: SessionDep,
|
||||
params: Annotated[
|
||||
dict,
|
||||
Depends(ArticleCrud.paginate_params(default_page_size=20, max_page_size=100)),
|
||||
],
|
||||
filter_by: Annotated[dict[str, list[str]], Depends(ArticleCrud.filter_params())],
|
||||
order_by: Annotated[
|
||||
OrderByClause | None,
|
||||
Depends(ArticleCrud.order_params(default_field=Article.created_at)),
|
||||
],
|
||||
search: str | None = None,
|
||||
) -> PaginatedResponse[ArticleRead]:
|
||||
return await ArticleCrud.paginate(
|
||||
session,
|
||||
**params,
|
||||
search=search,
|
||||
filter_by=filter_by or None,
|
||||
order_by=order_by,
|
||||
schema=ArticleRead,
|
||||
)
|
||||
13
docs_src/examples/pagination_search/schemas.py
Normal file
13
docs_src/examples/pagination_search/schemas.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
from fastapi_toolsets.schemas import PydanticBase
|
||||
|
||||
|
||||
class ArticleRead(PydanticBase):
|
||||
id: uuid.UUID
|
||||
created_at: datetime.datetime
|
||||
title: str
|
||||
status: str
|
||||
published: bool
|
||||
category_id: uuid.UUID | None
|
||||
@@ -1,7 +1,7 @@
|
||||
[project]
|
||||
name = "fastapi-toolsets"
|
||||
version = "0.3.0"
|
||||
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
|
||||
version = "2.4.2"
|
||||
description = "Production-ready utilities for FastAPI applications"
|
||||
readme = "README.md"
|
||||
license = "MIT"
|
||||
license-files = ["LICENSE"]
|
||||
@@ -11,7 +11,7 @@ authors = [
|
||||
]
|
||||
keywords = ["fastapi", "sqlalchemy", "postgresql"]
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Development Status :: 5 - Production/Stable",
|
||||
"Framework :: AsyncIO",
|
||||
"Framework :: FastAPI",
|
||||
"Framework :: Pydantic",
|
||||
@@ -24,18 +24,17 @@ classifiers = [
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Programming Language :: Python :: 3.14",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
"Topic :: Software Development :: Libraries",
|
||||
"Topic :: Software Development",
|
||||
"Typing :: Typed",
|
||||
]
|
||||
dependencies = [
|
||||
"fastapi>=0.100.0",
|
||||
"sqlalchemy[asyncio]>=2.0",
|
||||
"asyncpg>=0.29.0",
|
||||
"fastapi>=0.100.0",
|
||||
"pydantic>=2.0",
|
||||
"typer>=0.9.0",
|
||||
"httpx>=0.25.0",
|
||||
"sqlalchemy[asyncio]>=2.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
@@ -45,23 +44,47 @@ Repository = "https://github.com/d3vyce/fastapi-toolsets"
|
||||
Issues = "https://github.com/d3vyce/fastapi-toolsets/issues"
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-anyio>=0.0.0",
|
||||
"coverage>=7.0.0",
|
||||
"pytest-cov>=4.0.0",
|
||||
cli = [
|
||||
"typer>=0.9.0",
|
||||
]
|
||||
dev = [
|
||||
"fastapi-toolsets[test]",
|
||||
"ruff>=0.1.0",
|
||||
"ty>=0.0.1a0",
|
||||
metrics = [
|
||||
"prometheus_client>=0.20.0",
|
||||
]
|
||||
pytest = [
|
||||
"httpx>=0.25.0",
|
||||
"pytest-xdist>=3.0.0",
|
||||
"pytest>=8.0.0",
|
||||
]
|
||||
all = [
|
||||
"fastapi-toolsets[cli,metrics,pytest]",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
fastapi-toolsets = "fastapi_toolsets.cli:app"
|
||||
manager = "fastapi_toolsets.cli.app:cli"
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
{include-group = "tests"},
|
||||
{include-group = "docs"},
|
||||
"fastapi-toolsets[all]",
|
||||
"ruff>=0.1.0",
|
||||
"ty>=0.0.1a0",
|
||||
]
|
||||
tests = [
|
||||
"coverage>=7.0.0",
|
||||
"httpx>=0.25.0",
|
||||
"pytest-anyio>=0.0.0",
|
||||
"pytest-cov>=4.0.0",
|
||||
"pytest-xdist>=3.0.0",
|
||||
"pytest>=8.0.0",
|
||||
]
|
||||
docs = [
|
||||
"mkdocstrings-python>=2.0.2",
|
||||
"zensical>=0.0.23",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.9.26,<0.10.0"]
|
||||
requires = ["uv_build>=0.10,<0.11.0"]
|
||||
build-backend = "uv_build"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
@@ -21,4 +21,4 @@ Example usage:
|
||||
return Response(data={"user": user.username}, message="Success")
|
||||
"""
|
||||
|
||||
__version__ = "0.3.0"
|
||||
__version__ = "2.4.2"
|
||||
|
||||
9
src/fastapi_toolsets/_imports.py
Normal file
9
src/fastapi_toolsets/_imports.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Optional dependency helpers."""
|
||||
|
||||
|
||||
def require_extra(package: str, extra: str) -> None:
|
||||
"""Raise *ImportError* with an actionable install instruction."""
|
||||
raise ImportError(
|
||||
f"'{package}' is required to use this module. "
|
||||
f"Install it with: pip install fastapi-toolsets[{extra}]"
|
||||
)
|
||||
@@ -1,5 +1,5 @@
|
||||
"""CLI for FastAPI projects."""
|
||||
|
||||
from .app import app, register_command
|
||||
from .utils import async_command
|
||||
|
||||
__all__ = ["app", "register_command"]
|
||||
__all__ = ["async_command"]
|
||||
|
||||
@@ -1,97 +1,37 @@
|
||||
"""Main CLI application."""
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
try:
|
||||
import typer
|
||||
except ImportError:
|
||||
from .._imports import require_extra
|
||||
|
||||
import typer
|
||||
require_extra(package="typer", extra="cli")
|
||||
|
||||
from .commands import fixtures
|
||||
from ..logger import configure_logging
|
||||
from .config import get_custom_cli
|
||||
from .pyproject import load_pyproject
|
||||
|
||||
app = typer.Typer(
|
||||
name="fastapi-utils",
|
||||
help="CLI utilities for FastAPI projects.",
|
||||
no_args_is_help=True,
|
||||
)
|
||||
# Use custom CLI if configured, otherwise create default one
|
||||
_custom_cli = get_custom_cli()
|
||||
|
||||
# Register built-in commands
|
||||
app.add_typer(fixtures.app, name="fixtures")
|
||||
if _custom_cli is not None:
|
||||
cli = _custom_cli
|
||||
else:
|
||||
cli = typer.Typer(
|
||||
name="manager",
|
||||
help="CLI utilities for FastAPI projects.",
|
||||
no_args_is_help=True,
|
||||
)
|
||||
|
||||
_config = load_pyproject()
|
||||
if _config.get("fixtures") and _config.get("db_context"):
|
||||
from .commands.fixtures import fixture_cli
|
||||
|
||||
cli.add_typer(fixture_cli, name="fixtures")
|
||||
|
||||
|
||||
def register_command(command: typer.Typer, name: str) -> None:
|
||||
"""Register a custom command group.
|
||||
|
||||
Args:
|
||||
command: Typer app for the command group
|
||||
name: Name for the command group
|
||||
|
||||
Example:
|
||||
# In your project's cli.py:
|
||||
import typer
|
||||
from fastapi_toolsets.cli import app, register_command
|
||||
|
||||
my_commands = typer.Typer()
|
||||
|
||||
@my_commands.command()
|
||||
def seed():
|
||||
'''Seed the database.'''
|
||||
...
|
||||
|
||||
register_command(my_commands, "db")
|
||||
# Now available as: fastapi-utils db seed
|
||||
"""
|
||||
app.add_typer(command, name=name)
|
||||
|
||||
|
||||
@app.callback()
|
||||
def main(
|
||||
ctx: typer.Context,
|
||||
config: Annotated[
|
||||
Path | None,
|
||||
typer.Option(
|
||||
"--config",
|
||||
"-c",
|
||||
help="Path to project config file (Python module with fixtures registry).",
|
||||
envvar="FASTAPI_TOOLSETS_CONFIG",
|
||||
),
|
||||
] = None,
|
||||
) -> None:
|
||||
@cli.callback()
|
||||
def main(ctx: typer.Context) -> None:
|
||||
"""FastAPI utilities CLI."""
|
||||
configure_logging()
|
||||
ctx.ensure_object(dict)
|
||||
|
||||
if config:
|
||||
ctx.obj["config_path"] = config
|
||||
# Load the config module
|
||||
config_module = _load_module_from_path(config)
|
||||
ctx.obj["config_module"] = config_module
|
||||
|
||||
|
||||
def _load_module_from_path(path: Path) -> object:
|
||||
"""Load a Python module from a file path.
|
||||
|
||||
Handles both absolute and relative imports by adding the config's
|
||||
parent directory to sys.path temporarily.
|
||||
"""
|
||||
path = path.resolve()
|
||||
|
||||
# Add the parent directory to sys.path to support relative imports
|
||||
parent_dir = str(
|
||||
path.parent.parent
|
||||
) # Go up two levels (e.g., from app/cli_config.py to project root)
|
||||
if parent_dir not in sys.path:
|
||||
sys.path.insert(0, parent_dir)
|
||||
|
||||
# Also add immediate parent for direct module imports
|
||||
immediate_parent = str(path.parent)
|
||||
if immediate_parent not in sys.path:
|
||||
sys.path.insert(0, immediate_parent)
|
||||
|
||||
spec = importlib.util.spec_from_file_location("config", path)
|
||||
if spec is None or spec.loader is None:
|
||||
raise typer.BadParameter(f"Cannot load module from {path}")
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules["config"] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
@@ -1,138 +1,66 @@
|
||||
"""Fixture management commands."""
|
||||
|
||||
import asyncio
|
||||
from typing import Annotated
|
||||
|
||||
import typer
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
from ...fixtures import Context, FixtureRegistry, LoadStrategy, load_fixtures_by_context
|
||||
from ...fixtures import Context, LoadStrategy, load_fixtures_by_context
|
||||
from ..config import get_db_context, get_fixtures_registry
|
||||
from ..utils import async_command
|
||||
|
||||
app = typer.Typer(
|
||||
fixture_cli = typer.Typer(
|
||||
name="fixtures",
|
||||
help="Manage database fixtures.",
|
||||
no_args_is_help=True,
|
||||
)
|
||||
console = Console()
|
||||
|
||||
|
||||
def _get_registry(ctx: typer.Context) -> FixtureRegistry:
|
||||
"""Get fixture registry from context."""
|
||||
config = ctx.obj.get("config_module") if ctx.obj else None
|
||||
if config is None:
|
||||
raise typer.BadParameter(
|
||||
"No config provided. Use --config to specify a config file with a 'fixtures' registry."
|
||||
)
|
||||
|
||||
registry = getattr(config, "fixtures", None)
|
||||
if registry is None:
|
||||
raise typer.BadParameter(
|
||||
"Config module must have a 'fixtures' attribute (FixtureRegistry instance)."
|
||||
)
|
||||
|
||||
if not isinstance(registry, FixtureRegistry):
|
||||
raise typer.BadParameter(
|
||||
f"'fixtures' must be a FixtureRegistry instance, got {type(registry).__name__}"
|
||||
)
|
||||
|
||||
return registry
|
||||
|
||||
|
||||
def _get_db_context(ctx: typer.Context):
|
||||
"""Get database context manager from config."""
|
||||
config = ctx.obj.get("config_module") if ctx.obj else None
|
||||
if config is None:
|
||||
raise typer.BadParameter("No config provided.")
|
||||
|
||||
get_db_context = getattr(config, "get_db_context", None)
|
||||
if get_db_context is None:
|
||||
raise typer.BadParameter("Config module must have a 'get_db_context' function.")
|
||||
|
||||
return get_db_context
|
||||
|
||||
|
||||
@app.command("list")
|
||||
@fixture_cli.command("list")
|
||||
def list_fixtures(
|
||||
ctx: typer.Context,
|
||||
context: Annotated[
|
||||
str | None,
|
||||
Context | None,
|
||||
typer.Option(
|
||||
"--context",
|
||||
"-c",
|
||||
help="Filter by context (base, production, development, testing).",
|
||||
help="Filter by context.",
|
||||
),
|
||||
] = None,
|
||||
) -> None:
|
||||
"""List all registered fixtures."""
|
||||
registry = _get_registry(ctx)
|
||||
|
||||
if context:
|
||||
fixtures = registry.get_by_context(context)
|
||||
else:
|
||||
fixtures = registry.get_all()
|
||||
registry = get_fixtures_registry()
|
||||
fixtures = registry.get_by_context(context.value) if context else registry.get_all()
|
||||
|
||||
if not fixtures:
|
||||
typer.echo("No fixtures found.")
|
||||
print("No fixtures found.")
|
||||
return
|
||||
|
||||
typer.echo(f"\n{'Name':<30} {'Contexts':<30} {'Dependencies'}")
|
||||
typer.echo("-" * 80)
|
||||
table = Table("Name", "Contexts", "Dependencies")
|
||||
|
||||
for fixture in fixtures:
|
||||
contexts = ", ".join(fixture.contexts)
|
||||
deps = ", ".join(fixture.depends_on) if fixture.depends_on else "-"
|
||||
typer.echo(f"{fixture.name:<30} {contexts:<30} {deps}")
|
||||
table.add_row(fixture.name, contexts, deps)
|
||||
|
||||
typer.echo(f"\nTotal: {len(fixtures)} fixture(s)")
|
||||
console.print(table)
|
||||
print(f"\nTotal: {len(fixtures)} fixture(s)")
|
||||
|
||||
|
||||
@app.command("graph")
|
||||
def show_graph(
|
||||
ctx: typer.Context,
|
||||
fixture_name: Annotated[
|
||||
str | None,
|
||||
typer.Argument(help="Show dependencies for a specific fixture."),
|
||||
] = None,
|
||||
) -> None:
|
||||
"""Show fixture dependency graph."""
|
||||
registry = _get_registry(ctx)
|
||||
|
||||
if fixture_name:
|
||||
try:
|
||||
order = registry.resolve_dependencies(fixture_name)
|
||||
typer.echo(f"\nDependency chain for '{fixture_name}':\n")
|
||||
for i, name in enumerate(order):
|
||||
indent = " " * i
|
||||
arrow = "└─> " if i > 0 else ""
|
||||
typer.echo(f"{indent}{arrow}{name}")
|
||||
except KeyError:
|
||||
typer.echo(f"Fixture '{fixture_name}' not found.", err=True)
|
||||
raise typer.Exit(1)
|
||||
else:
|
||||
# Show full graph
|
||||
fixtures = registry.get_all()
|
||||
|
||||
typer.echo("\nFixture Dependency Graph:\n")
|
||||
for fixture in fixtures:
|
||||
deps = (
|
||||
f" -> [{', '.join(fixture.depends_on)}]" if fixture.depends_on else ""
|
||||
)
|
||||
typer.echo(f" {fixture.name}{deps}")
|
||||
|
||||
|
||||
@app.command("load")
|
||||
def load(
|
||||
@fixture_cli.command("load")
|
||||
@async_command
|
||||
async def load(
|
||||
ctx: typer.Context,
|
||||
contexts: Annotated[
|
||||
list[str] | None,
|
||||
typer.Argument(
|
||||
help="Contexts to load (base, production, development, testing)."
|
||||
),
|
||||
list[Context] | None,
|
||||
typer.Argument(help="Contexts to load."),
|
||||
] = None,
|
||||
strategy: Annotated[
|
||||
str,
|
||||
typer.Option(
|
||||
"--strategy", "-s", help="Load strategy: merge, insert, skip_existing."
|
||||
),
|
||||
] = "merge",
|
||||
LoadStrategy,
|
||||
typer.Option("--strategy", "-s", help="Load strategy."),
|
||||
] = LoadStrategy.MERGE,
|
||||
dry_run: Annotated[
|
||||
bool,
|
||||
typer.Option(
|
||||
@@ -141,85 +69,32 @@ def load(
|
||||
] = False,
|
||||
) -> None:
|
||||
"""Load fixtures into the database."""
|
||||
registry = _get_registry(ctx)
|
||||
get_db_context = _get_db_context(ctx)
|
||||
registry = get_fixtures_registry()
|
||||
db_context = get_db_context()
|
||||
|
||||
# Parse contexts
|
||||
if contexts:
|
||||
context_list = contexts
|
||||
else:
|
||||
context_list = [Context.BASE]
|
||||
context_list = list(contexts) if contexts else [Context.BASE]
|
||||
|
||||
# Parse strategy
|
||||
try:
|
||||
load_strategy = LoadStrategy(strategy)
|
||||
except ValueError:
|
||||
typer.echo(
|
||||
f"Invalid strategy: {strategy}. Use: merge, insert, skip_existing", err=True
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Resolve what will be loaded
|
||||
ordered = registry.resolve_context_dependencies(*context_list)
|
||||
|
||||
if not ordered:
|
||||
typer.echo("No fixtures to load for the specified context(s).")
|
||||
print("No fixtures to load for the specified context(s).")
|
||||
return
|
||||
|
||||
typer.echo(f"\nFixtures to load ({load_strategy.value} strategy):")
|
||||
print(f"\nFixtures to load ({strategy.value} strategy):")
|
||||
for name in ordered:
|
||||
fixture = registry.get(name)
|
||||
instances = list(fixture.func())
|
||||
model_name = type(instances[0]).__name__ if instances else "?"
|
||||
typer.echo(f" - {name}: {len(instances)} {model_name}(s)")
|
||||
print(f" - {name}: {len(instances)} {model_name}(s)")
|
||||
|
||||
if dry_run:
|
||||
typer.echo("\n[Dry run - no changes made]")
|
||||
print("\n[Dry run - no changes made]")
|
||||
return
|
||||
|
||||
typer.echo("\nLoading...")
|
||||
|
||||
async def do_load():
|
||||
async with get_db_context() as session:
|
||||
result = await load_fixtures_by_context(
|
||||
session, registry, *context_list, strategy=load_strategy
|
||||
)
|
||||
return result
|
||||
|
||||
result = asyncio.run(do_load())
|
||||
async with db_context() as session:
|
||||
result = await load_fixtures_by_context(
|
||||
session, registry, *context_list, strategy=strategy
|
||||
)
|
||||
|
||||
total = sum(len(items) for items in result.values())
|
||||
typer.echo(f"\nLoaded {total} record(s) successfully.")
|
||||
|
||||
|
||||
@app.command("show")
|
||||
def show_fixture(
|
||||
ctx: typer.Context,
|
||||
name: Annotated[str, typer.Argument(help="Fixture name to show.")],
|
||||
) -> None:
|
||||
"""Show details of a specific fixture."""
|
||||
registry = _get_registry(ctx)
|
||||
|
||||
try:
|
||||
fixture = registry.get(name)
|
||||
except KeyError:
|
||||
typer.echo(f"Fixture '{name}' not found.", err=True)
|
||||
raise typer.Exit(1)
|
||||
|
||||
typer.echo(f"\nFixture: {fixture.name}")
|
||||
typer.echo(f"Contexts: {', '.join(fixture.contexts)}")
|
||||
typer.echo(
|
||||
f"Dependencies: {', '.join(fixture.depends_on) if fixture.depends_on else 'None'}"
|
||||
)
|
||||
|
||||
# Show instances
|
||||
instances = list(fixture.func())
|
||||
if instances:
|
||||
model_name = type(instances[0]).__name__
|
||||
typer.echo(f"\nInstances ({len(instances)} {model_name}):")
|
||||
for instance in instances[:10]: # Limit to 10
|
||||
typer.echo(f" - {instance!r}")
|
||||
if len(instances) > 10:
|
||||
typer.echo(f" ... and {len(instances) - 10} more")
|
||||
else:
|
||||
typer.echo("\nNo instances (empty fixture)")
|
||||
print(f"\nLoaded {total} record(s) successfully.")
|
||||
|
||||
125
src/fastapi_toolsets/cli/config.py
Normal file
125
src/fastapi_toolsets/cli/config.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""CLI configuration and dynamic imports."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any, Literal, overload
|
||||
|
||||
import typer
|
||||
|
||||
from .pyproject import find_pyproject, load_pyproject
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..fixtures import FixtureRegistry
|
||||
|
||||
|
||||
def _ensure_project_in_path():
|
||||
"""Add project root to sys.path if not installed in editable mode."""
|
||||
pyproject = find_pyproject()
|
||||
if pyproject:
|
||||
project_root = str(pyproject.parent)
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
|
||||
def import_from_string(import_path: str) -> Any:
|
||||
"""Import an object from a dotted string path.
|
||||
|
||||
Args:
|
||||
import_path: Import path in ``"module.submodule:attribute"`` format
|
||||
|
||||
Returns:
|
||||
The imported attribute
|
||||
|
||||
Raises:
|
||||
typer.BadParameter: If the import path is invalid or import fails
|
||||
"""
|
||||
if ":" not in import_path:
|
||||
raise typer.BadParameter(
|
||||
f"Invalid import path '{import_path}'. Expected format: 'module:attribute'"
|
||||
)
|
||||
|
||||
module_path, attr_name = import_path.rsplit(":", 1)
|
||||
|
||||
_ensure_project_in_path()
|
||||
|
||||
try:
|
||||
module = importlib.import_module(module_path)
|
||||
except ImportError as e:
|
||||
raise typer.BadParameter(f"Cannot import module '{module_path}': {e}")
|
||||
|
||||
if not hasattr(module, attr_name):
|
||||
raise typer.BadParameter(
|
||||
f"Module '{module_path}' has no attribute '{attr_name}'"
|
||||
)
|
||||
|
||||
return getattr(module, attr_name)
|
||||
|
||||
|
||||
@overload
|
||||
def get_config_value(key: str, required: Literal[True]) -> Any: ... # pragma: no cover
|
||||
@overload
|
||||
def get_config_value(
|
||||
key: str, required: bool = False
|
||||
) -> Any | None: ... # pragma: no cover
|
||||
def get_config_value(key: str, required: bool = False) -> Any | None:
|
||||
"""Get a configuration value from pyproject.toml.
|
||||
|
||||
Args:
|
||||
key: The configuration key in [tool.fastapi-toolsets].
|
||||
required: If True, raises an error when the key is missing.
|
||||
|
||||
Returns:
|
||||
The configuration value, or None if not found and not required.
|
||||
|
||||
Raises:
|
||||
typer.BadParameter: If required=True and the key is missing.
|
||||
"""
|
||||
config = load_pyproject()
|
||||
value = config.get(key)
|
||||
|
||||
if required and value is None:
|
||||
raise typer.BadParameter(
|
||||
f"No '{key}' configured. "
|
||||
f"Add '{key}' to [tool.fastapi-toolsets] in pyproject.toml."
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def get_fixtures_registry() -> FixtureRegistry:
|
||||
"""Import and return the fixtures registry from config."""
|
||||
from ..fixtures import FixtureRegistry
|
||||
|
||||
import_path = get_config_value("fixtures", required=True)
|
||||
registry = import_from_string(import_path)
|
||||
|
||||
if not isinstance(registry, FixtureRegistry):
|
||||
raise typer.BadParameter(
|
||||
f"'fixtures' must be a FixtureRegistry instance, got {type(registry).__name__}"
|
||||
)
|
||||
|
||||
return registry
|
||||
|
||||
|
||||
def get_db_context() -> Any:
|
||||
"""Import and return the db_context function from config."""
|
||||
import_path = get_config_value("db_context", required=True)
|
||||
return import_from_string(import_path)
|
||||
|
||||
|
||||
def get_custom_cli() -> typer.Typer | None:
|
||||
"""Import and return the custom CLI Typer instance from config."""
|
||||
import_path = get_config_value("custom_cli")
|
||||
if not import_path:
|
||||
return None
|
||||
|
||||
custom = import_from_string(import_path)
|
||||
|
||||
if not isinstance(custom, typer.Typer):
|
||||
raise typer.BadParameter(
|
||||
f"'custom_cli' must be a Typer instance, got {type(custom).__name__}"
|
||||
)
|
||||
|
||||
return custom
|
||||
49
src/fastapi_toolsets/cli/pyproject.py
Normal file
49
src/fastapi_toolsets/cli/pyproject.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Pyproject.toml discovery and loading."""
|
||||
|
||||
import tomllib
|
||||
from pathlib import Path
|
||||
|
||||
TOOL_NAME = "fastapi-toolsets"
|
||||
|
||||
|
||||
def find_pyproject(start_path: Path | None = None) -> Path | None:
|
||||
"""Find pyproject.toml by walking up the directory tree.
|
||||
|
||||
Similar to how pytest, black, and ruff discover their config files.
|
||||
|
||||
Args:
|
||||
start_path: Directory to start searching from. Defaults to cwd.
|
||||
|
||||
Returns:
|
||||
Path to pyproject.toml, or None if not found.
|
||||
"""
|
||||
path = (start_path or Path.cwd()).resolve()
|
||||
|
||||
for directory in [path, *path.parents]:
|
||||
pyproject = directory / "pyproject.toml"
|
||||
if pyproject.is_file():
|
||||
return pyproject
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def load_pyproject(path: Path | None = None) -> dict:
|
||||
"""Load tool configuration from pyproject.toml.
|
||||
|
||||
Args:
|
||||
path: Explicit path to pyproject.toml. If None, searches up from cwd.
|
||||
|
||||
Returns:
|
||||
The [tool.fastapi-toolsets] section as a dict, or empty dict if not found.
|
||||
"""
|
||||
pyproject_path = path or find_pyproject()
|
||||
|
||||
if not pyproject_path:
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(pyproject_path, "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
return data.get("tool", {}).get(TOOL_NAME, {})
|
||||
except (OSError, tomllib.TOMLDecodeError):
|
||||
return {}
|
||||
29
src/fastapi_toolsets/cli/utils.py
Normal file
29
src/fastapi_toolsets/cli/utils.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""CLI utility functions."""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def async_command(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, T]:
|
||||
"""Decorator to run an async function as a sync CLI command.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@fixture_cli.command("load")
|
||||
@async_command
|
||||
async def load(ctx: typer.Context) -> None:
|
||||
async with get_db_context() as session:
|
||||
await load_fixtures(session, registry)
|
||||
```
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
return asyncio.run(func(*args, **kwargs))
|
||||
|
||||
return wrapper
|
||||
@@ -1,378 +0,0 @@
|
||||
"""Generic async CRUD operations for SQLAlchemy models."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, ClassVar, Generic, Self, TypeVar, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_, func, select
|
||||
from sqlalchemy import delete as sql_delete
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.sql.roles import WhereHavingRole
|
||||
|
||||
from .db import get_transaction
|
||||
from .exceptions import NotFoundError
|
||||
|
||||
__all__ = [
|
||||
"AsyncCrud",
|
||||
"CrudFactory",
|
||||
]
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||
|
||||
|
||||
class AsyncCrud(Generic[ModelType]):
|
||||
"""Generic async CRUD operations for SQLAlchemy models.
|
||||
|
||||
Subclass this and set the `model` class variable, or use `CrudFactory`.
|
||||
|
||||
Example:
|
||||
class UserCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
|
||||
# Or use the factory:
|
||||
UserCrud = CrudFactory(User)
|
||||
|
||||
# Then use it:
|
||||
user = await UserCrud.get(session, [User.id == 1])
|
||||
users = await UserCrud.get_multi(session, limit=10)
|
||||
"""
|
||||
|
||||
model: ClassVar[type[DeclarativeBase]]
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
) -> ModelType:
|
||||
"""Create a new record in the database.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
obj: Pydantic model with data to create
|
||||
|
||||
Returns:
|
||||
Created model instance
|
||||
"""
|
||||
async with get_transaction(session):
|
||||
db_model = cls.model(**obj.model_dump())
|
||||
session.add(db_model)
|
||||
await session.refresh(db_model)
|
||||
return cast(ModelType, db_model)
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
*,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
) -> ModelType:
|
||||
"""Get exactly one record. Raises NotFoundError if not found.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
with_for_update: Lock the row for update
|
||||
load_options: SQLAlchemy loader options (e.g., selectinload)
|
||||
|
||||
Returns:
|
||||
Model instance
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no record found
|
||||
MultipleResultsFound: If more than one record found
|
||||
"""
|
||||
q = select(cls.model).where(and_(*filters))
|
||||
if load_options:
|
||||
q = q.options(*load_options)
|
||||
if with_for_update:
|
||||
q = q.with_for_update()
|
||||
result = await session.execute(q)
|
||||
item = result.unique().scalar_one_or_none()
|
||||
if not item:
|
||||
raise NotFoundError()
|
||||
return cast(ModelType, item)
|
||||
|
||||
@classmethod
|
||||
async def first(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any] | None = None,
|
||||
*,
|
||||
load_options: list[Any] | None = None,
|
||||
) -> ModelType | None:
|
||||
"""Get the first matching record, or None.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
load_options: SQLAlchemy loader options
|
||||
|
||||
Returns:
|
||||
Model instance or None
|
||||
"""
|
||||
q = select(cls.model)
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
if load_options:
|
||||
q = q.options(*load_options)
|
||||
result = await session.execute(q)
|
||||
return cast(ModelType | None, result.unique().scalars().first())
|
||||
|
||||
@classmethod
|
||||
async def get_multi(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
*,
|
||||
filters: list[Any] | None = None,
|
||||
load_options: list[Any] | None = None,
|
||||
order_by: Any | None = None,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> Sequence[ModelType]:
|
||||
"""Get multiple records from the database.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
load_options: SQLAlchemy loader options
|
||||
order_by: Column or list of columns to order by
|
||||
limit: Max number of rows to return
|
||||
offset: Rows to skip
|
||||
|
||||
Returns:
|
||||
List of model instances
|
||||
"""
|
||||
q = select(cls.model)
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
if load_options:
|
||||
q = q.options(*load_options)
|
||||
if order_by is not None:
|
||||
q = q.order_by(order_by)
|
||||
if offset is not None:
|
||||
q = q.offset(offset)
|
||||
if limit is not None:
|
||||
q = q.limit(limit)
|
||||
result = await session.execute(q)
|
||||
return cast(Sequence[ModelType], result.unique().scalars().all())
|
||||
|
||||
@classmethod
|
||||
async def update(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
filters: list[Any],
|
||||
*,
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
) -> ModelType:
|
||||
"""Update a record in the database.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
obj: Pydantic model with update data
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
exclude_unset: Exclude fields not explicitly set in the schema
|
||||
exclude_none: Exclude fields with None value
|
||||
|
||||
Returns:
|
||||
Updated model instance
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no record found
|
||||
"""
|
||||
async with get_transaction(session):
|
||||
db_model = await cls.get(session=session, filters=filters)
|
||||
values = obj.model_dump(
|
||||
exclude_unset=exclude_unset, exclude_none=exclude_none
|
||||
)
|
||||
for key, value in values.items():
|
||||
setattr(db_model, key, value)
|
||||
await session.refresh(db_model)
|
||||
return db_model
|
||||
|
||||
@classmethod
|
||||
async def upsert(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
index_elements: list[str],
|
||||
*,
|
||||
set_: BaseModel | None = None,
|
||||
where: WhereHavingRole | None = None,
|
||||
) -> ModelType | None:
|
||||
"""Create or update a record (PostgreSQL only).
|
||||
|
||||
Uses INSERT ... ON CONFLICT for atomic upsert.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
obj: Pydantic model with data
|
||||
index_elements: Columns for ON CONFLICT (unique constraint)
|
||||
set_: Pydantic model for ON CONFLICT DO UPDATE SET
|
||||
where: WHERE clause for ON CONFLICT DO UPDATE
|
||||
|
||||
Returns:
|
||||
Model instance
|
||||
"""
|
||||
async with get_transaction(session):
|
||||
values = obj.model_dump(exclude_unset=True)
|
||||
q = insert(cls.model).values(**values)
|
||||
if set_:
|
||||
q = q.on_conflict_do_update(
|
||||
index_elements=index_elements,
|
||||
set_=set_.model_dump(exclude_unset=True),
|
||||
where=where,
|
||||
)
|
||||
else:
|
||||
q = q.on_conflict_do_nothing(index_elements=index_elements)
|
||||
q = q.returning(cls.model)
|
||||
result = await session.execute(q)
|
||||
try:
|
||||
db_model = result.unique().scalar_one()
|
||||
except NoResultFound:
|
||||
db_model = await cls.first(
|
||||
session=session,
|
||||
filters=[getattr(cls.model, k) == v for k, v in values.items()],
|
||||
)
|
||||
return cast(ModelType | None, db_model)
|
||||
|
||||
@classmethod
|
||||
async def delete(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
) -> bool:
|
||||
"""Delete records from the database.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
|
||||
Returns:
|
||||
True if deletion was executed
|
||||
"""
|
||||
async with get_transaction(session):
|
||||
q = sql_delete(cls.model).where(and_(*filters))
|
||||
await session.execute(q)
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def count(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any] | None = None,
|
||||
) -> int:
|
||||
"""Count records matching the filters.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
|
||||
Returns:
|
||||
Number of matching records
|
||||
"""
|
||||
q = select(func.count()).select_from(cls.model)
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
result = await session.execute(q)
|
||||
return result.scalar_one()
|
||||
|
||||
@classmethod
|
||||
async def exists(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
) -> bool:
|
||||
"""Check if a record exists.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
|
||||
Returns:
|
||||
True if at least one record matches
|
||||
"""
|
||||
q = select(cls.model).where(and_(*filters)).exists().select()
|
||||
result = await session.execute(q)
|
||||
return bool(result.scalar())
|
||||
|
||||
@classmethod
|
||||
async def paginate(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
*,
|
||||
filters: list[Any] | None = None,
|
||||
load_options: list[Any] | None = None,
|
||||
order_by: Any | None = None,
|
||||
page: int = 1,
|
||||
items_per_page: int = 20,
|
||||
) -> dict[str, Any]:
|
||||
"""Get paginated results with metadata.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
load_options: SQLAlchemy loader options
|
||||
order_by: Column or list of columns to order by
|
||||
page: Page number (1-indexed)
|
||||
items_per_page: Number of items per page
|
||||
|
||||
Returns:
|
||||
Dict with 'data' and 'pagination' keys
|
||||
"""
|
||||
filters = filters or []
|
||||
offset = (page - 1) * items_per_page
|
||||
|
||||
items = await cls.get_multi(
|
||||
session,
|
||||
filters=filters,
|
||||
load_options=load_options,
|
||||
order_by=order_by,
|
||||
limit=items_per_page,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
total_count = await cls.count(session, filters=filters)
|
||||
|
||||
return {
|
||||
"data": items,
|
||||
"pagination": {
|
||||
"total_count": total_count,
|
||||
"items_per_page": items_per_page,
|
||||
"page": page,
|
||||
"has_more": page * items_per_page < total_count,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def CrudFactory(
|
||||
model: type[ModelType],
|
||||
) -> type[AsyncCrud[ModelType]]:
|
||||
"""Create a CRUD class for a specific model.
|
||||
|
||||
Args:
|
||||
model: SQLAlchemy model class
|
||||
|
||||
Returns:
|
||||
AsyncCrud subclass bound to the model
|
||||
|
||||
Example:
|
||||
from fastapi_toolsets.crud import CrudFactory
|
||||
from myapp.models import User, Post
|
||||
|
||||
UserCrud = CrudFactory(User)
|
||||
PostCrud = CrudFactory(Post)
|
||||
|
||||
# Usage
|
||||
user = await UserCrud.get(session, [User.id == 1])
|
||||
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
|
||||
"""
|
||||
cls = type(f"Async{model.__name__}Crud", (AsyncCrud,), {"model": model})
|
||||
return cast(type[AsyncCrud[ModelType]], cls)
|
||||
28
src/fastapi_toolsets/crud/__init__.py
Normal file
28
src/fastapi_toolsets/crud/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Generic async CRUD operations for SQLAlchemy models."""
|
||||
|
||||
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
|
||||
from ..schemas import PaginationType
|
||||
from ..types import (
|
||||
FacetFieldType,
|
||||
JoinType,
|
||||
M2MFieldType,
|
||||
OrderByClause,
|
||||
SearchFieldType,
|
||||
)
|
||||
from .factory import AsyncCrud, CrudFactory
|
||||
from .search import SearchConfig, get_searchable_fields
|
||||
|
||||
__all__ = [
|
||||
"AsyncCrud",
|
||||
"CrudFactory",
|
||||
"FacetFieldType",
|
||||
"get_searchable_fields",
|
||||
"InvalidFacetFilterError",
|
||||
"JoinType",
|
||||
"M2MFieldType",
|
||||
"NoSearchableFieldsError",
|
||||
"OrderByClause",
|
||||
"PaginationType",
|
||||
"SearchConfig",
|
||||
"SearchFieldType",
|
||||
]
|
||||
1604
src/fastapi_toolsets/crud/factory.py
Normal file
1604
src/fastapi_toolsets/crud/factory.py
Normal file
File diff suppressed because it is too large
Load Diff
290
src/fastapi_toolsets/crud/search.py
Normal file
290
src/fastapi_toolsets/crud/search.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""Search utilities for AsyncCrud."""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from collections import Counter
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from sqlalchemy import String, and_, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
|
||||
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
|
||||
from ..types import FacetFieldType, SearchFieldType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchConfig:
|
||||
"""Advanced search configuration.
|
||||
|
||||
Attributes:
|
||||
query: The search string
|
||||
fields: Fields to search (columns or tuples for relationships)
|
||||
case_sensitive: Case-sensitive search (default: False)
|
||||
match_mode: "any" (OR) or "all" (AND) to combine fields
|
||||
"""
|
||||
|
||||
query: str
|
||||
fields: Sequence[SearchFieldType] | None = None
|
||||
case_sensitive: bool = False
|
||||
match_mode: Literal["any", "all"] = "any"
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def get_searchable_fields(
|
||||
model: type[DeclarativeBase],
|
||||
*,
|
||||
include_relationships: bool = True,
|
||||
max_depth: int = 1,
|
||||
) -> list[SearchFieldType]:
|
||||
"""Auto-detect String fields on a model and its relationships.
|
||||
|
||||
Args:
|
||||
model: SQLAlchemy model class
|
||||
include_relationships: Include fields from many-to-one/one-to-one relationships
|
||||
max_depth: Max depth for relationship traversal (default: 1)
|
||||
|
||||
Returns:
|
||||
List of columns and tuples (relationship, column)
|
||||
"""
|
||||
fields: list[SearchFieldType] = []
|
||||
mapper = model.__mapper__
|
||||
|
||||
# Direct String columns
|
||||
for col in mapper.columns:
|
||||
if isinstance(col.type, String):
|
||||
fields.append(getattr(model, col.key))
|
||||
|
||||
# Relationships (one-to-one, many-to-one only)
|
||||
if include_relationships and max_depth > 0:
|
||||
for rel_name, rel_prop in mapper.relationships.items():
|
||||
if rel_prop.uselist: # Skip collections (one-to-many, many-to-many)
|
||||
continue
|
||||
|
||||
rel_attr = getattr(model, rel_name)
|
||||
related_model = rel_prop.mapper.class_
|
||||
|
||||
for col in related_model.__mapper__.columns:
|
||||
if isinstance(col.type, String):
|
||||
fields.append((rel_attr, getattr(related_model, col.key)))
|
||||
|
||||
return fields
|
||||
|
||||
|
||||
def build_search_filters(
|
||||
model: type[DeclarativeBase],
|
||||
search: str | SearchConfig,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
default_fields: Sequence[SearchFieldType] | None = None,
|
||||
) -> tuple[list["ColumnElement[bool]"], list[InstrumentedAttribute[Any]]]:
|
||||
"""Build SQLAlchemy filter conditions for search.
|
||||
|
||||
Args:
|
||||
model: SQLAlchemy model class
|
||||
search: Search string or SearchConfig
|
||||
search_fields: Fields specified per-call (takes priority)
|
||||
default_fields: Default fields (from ClassVar)
|
||||
|
||||
Returns:
|
||||
Tuple of (filter_conditions, joins_needed)
|
||||
|
||||
Raises:
|
||||
NoSearchableFieldsError: If no searchable field has been configured
|
||||
"""
|
||||
# Normalize input
|
||||
if isinstance(search, str):
|
||||
config = SearchConfig(query=search, fields=search_fields)
|
||||
else:
|
||||
config = (
|
||||
replace(search, fields=search_fields)
|
||||
if search_fields is not None
|
||||
else search
|
||||
)
|
||||
|
||||
if not config.query or not config.query.strip():
|
||||
return [], []
|
||||
|
||||
# Determine which fields to search
|
||||
fields = config.fields or default_fields or get_searchable_fields(model)
|
||||
|
||||
if not fields:
|
||||
raise NoSearchableFieldsError(model)
|
||||
|
||||
query = config.query.strip()
|
||||
filters: list[ColumnElement[bool]] = []
|
||||
joins: list[InstrumentedAttribute[Any]] = []
|
||||
added_joins: set[str] = set()
|
||||
|
||||
for field in fields:
|
||||
if isinstance(field, tuple):
|
||||
# Relationship: (User.role, Role.name) or deeper
|
||||
for rel in field[:-1]:
|
||||
rel_key = str(rel)
|
||||
if rel_key not in added_joins:
|
||||
joins.append(rel)
|
||||
added_joins.add(rel_key)
|
||||
column = field[-1]
|
||||
else:
|
||||
column = field
|
||||
|
||||
# Build the filter (cast to String for non-text columns)
|
||||
column_as_string = column.cast(String)
|
||||
if config.case_sensitive:
|
||||
filters.append(column_as_string.like(f"%{query}%"))
|
||||
else:
|
||||
filters.append(column_as_string.ilike(f"%{query}%"))
|
||||
|
||||
if not filters: # pragma: no cover
|
||||
return [], []
|
||||
|
||||
# Combine based on match_mode
|
||||
if config.match_mode == "any":
|
||||
return [or_(*filters)], joins
|
||||
else:
|
||||
return filters, joins
|
||||
|
||||
|
||||
def facet_keys(facet_fields: Sequence[FacetFieldType]) -> list[str]:
|
||||
"""Return a key for each facet field, disambiguating duplicate column keys.
|
||||
|
||||
Args:
|
||||
facet_fields: Sequence of facet fields — either direct columns or
|
||||
relationship tuples ``(rel, ..., column)``.
|
||||
|
||||
Returns:
|
||||
A list of string keys, one per facet field, in the same order.
|
||||
"""
|
||||
raw: list[tuple[str, str | None]] = []
|
||||
for field in facet_fields:
|
||||
if isinstance(field, tuple):
|
||||
rel = field[-2]
|
||||
column = field[-1]
|
||||
raw.append((column.key, rel.key))
|
||||
else:
|
||||
raw.append((field.key, None))
|
||||
|
||||
counts = Counter(col_key for col_key, _ in raw)
|
||||
keys: list[str] = []
|
||||
for col_key, rel_key in raw:
|
||||
if counts[col_key] > 1 and rel_key is not None:
|
||||
keys.append(f"{rel_key}__{col_key}")
|
||||
else:
|
||||
keys.append(col_key)
|
||||
return keys
|
||||
|
||||
|
||||
async def build_facets(
|
||||
session: "AsyncSession",
|
||||
model: type[DeclarativeBase],
|
||||
facet_fields: Sequence[FacetFieldType],
|
||||
*,
|
||||
base_filters: "list[ColumnElement[bool]] | None" = None,
|
||||
base_joins: list[InstrumentedAttribute[Any]] | None = None,
|
||||
) -> dict[str, list[Any]]:
|
||||
"""Return distinct values for each facet field, respecting current filters.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
model: SQLAlchemy model class
|
||||
facet_fields: Columns or relationship tuples to facet on
|
||||
base_filters: Filter conditions already applied to the main query (search + caller filters)
|
||||
base_joins: Relationship joins already applied to the main query
|
||||
|
||||
Returns:
|
||||
Dict mapping column key to sorted list of distinct non-None values
|
||||
"""
|
||||
existing_join_keys: set[str] = {str(j) for j in (base_joins or [])}
|
||||
|
||||
keys = facet_keys(facet_fields)
|
||||
|
||||
async def _query_facet(field: FacetFieldType, key: str) -> tuple[str, list[Any]]:
|
||||
if isinstance(field, tuple):
|
||||
# Relationship chain: (User.role, Role.name) — last element is the column
|
||||
rels = field[:-1]
|
||||
column = field[-1]
|
||||
else:
|
||||
rels = ()
|
||||
column = field
|
||||
|
||||
q = select(column).select_from(model).distinct()
|
||||
|
||||
# Apply base joins (already done on main query, but needed here independently)
|
||||
for rel in base_joins or []:
|
||||
q = q.outerjoin(rel)
|
||||
|
||||
# Add any extra joins required by this facet field that aren't already in base_joins
|
||||
for rel in rels:
|
||||
if str(rel) not in existing_join_keys:
|
||||
q = q.outerjoin(rel)
|
||||
|
||||
if base_filters:
|
||||
q = q.where(and_(*base_filters))
|
||||
|
||||
q = q.order_by(column)
|
||||
result = await session.execute(q)
|
||||
values = [row[0] for row in result.all() if row[0] is not None]
|
||||
return key, values
|
||||
|
||||
pairs = await asyncio.gather(
|
||||
*[_query_facet(f, k) for f, k in zip(facet_fields, keys)]
|
||||
)
|
||||
return dict(pairs)
|
||||
|
||||
|
||||
def build_filter_by(
|
||||
filter_by: dict[str, Any],
|
||||
facet_fields: Sequence[FacetFieldType],
|
||||
) -> tuple["list[ColumnElement[bool]]", list[InstrumentedAttribute[Any]]]:
|
||||
"""Translate a {column_key: value} dict into SQLAlchemy filter conditions.
|
||||
|
||||
Args:
|
||||
filter_by: Mapping of column key to scalar value or list of values
|
||||
facet_fields: Declared facet fields to validate keys against
|
||||
|
||||
Returns:
|
||||
Tuple of (filter_conditions, joins_needed)
|
||||
|
||||
Raises:
|
||||
InvalidFacetFilterError: If a key in filter_by is not a declared facet field
|
||||
"""
|
||||
index: dict[
|
||||
str, tuple[InstrumentedAttribute[Any], list[InstrumentedAttribute[Any]]]
|
||||
] = {}
|
||||
for key, field in zip(facet_keys(facet_fields), facet_fields):
|
||||
if isinstance(field, tuple):
|
||||
rels = list(field[:-1])
|
||||
column = field[-1]
|
||||
else:
|
||||
rels = []
|
||||
column = field
|
||||
index[key] = (column, rels)
|
||||
|
||||
valid_keys = set(index)
|
||||
filters: list[ColumnElement[bool]] = []
|
||||
joins: list[InstrumentedAttribute[Any]] = []
|
||||
added_join_keys: set[str] = set()
|
||||
|
||||
for key, value in filter_by.items():
|
||||
if key not in index:
|
||||
raise InvalidFacetFilterError(key, valid_keys)
|
||||
|
||||
column, rels = index[key]
|
||||
|
||||
for rel in rels:
|
||||
rel_key = str(rel)
|
||||
if rel_key not in added_join_keys:
|
||||
joins.append(rel)
|
||||
added_join_keys.add(rel_key)
|
||||
|
||||
if isinstance(value, list):
|
||||
filters.append(column.in_(value))
|
||||
else:
|
||||
filters.append(column == value)
|
||||
|
||||
return filters, joins
|
||||
@@ -1,19 +1,26 @@
|
||||
"""Database utilities: sessions, transactions, and locks."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||
from enum import Enum
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from .exceptions import NotFoundError
|
||||
|
||||
__all__ = [
|
||||
"LockMode",
|
||||
"cleanup_tables",
|
||||
"create_database",
|
||||
"create_db_context",
|
||||
"create_db_dependency",
|
||||
"lock_tables",
|
||||
"get_transaction",
|
||||
"lock_tables",
|
||||
"wait_for_row_change",
|
||||
]
|
||||
|
||||
|
||||
@@ -32,6 +39,7 @@ def create_db_dependency(
|
||||
An async generator function usable with FastAPI's Depends()
|
||||
|
||||
Example:
|
||||
```python
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
from fastapi_toolsets.db import create_db_dependency
|
||||
@@ -43,6 +51,7 @@ def create_db_dependency(
|
||||
@app.get("/users")
|
||||
async def list_users(session: AsyncSession = Depends(get_db)):
|
||||
...
|
||||
```
|
||||
"""
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
@@ -69,6 +78,7 @@ def create_db_context(
|
||||
An async context manager function
|
||||
|
||||
Example:
|
||||
```python
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||
from fastapi_toolsets.db import create_db_context
|
||||
|
||||
@@ -80,6 +90,7 @@ def create_db_context(
|
||||
async with get_db_context() as session:
|
||||
user = await UserCrud.get(session, [User.id == 1])
|
||||
...
|
||||
```
|
||||
"""
|
||||
get_db = create_db_dependency(session_maker)
|
||||
return asynccontextmanager(get_db)
|
||||
@@ -101,9 +112,11 @@ async def get_transaction(
|
||||
The session within the transaction context
|
||||
|
||||
Example:
|
||||
```python
|
||||
async with get_transaction(session):
|
||||
session.add(model)
|
||||
# Auto-commits on exit, rolls back on exception
|
||||
```
|
||||
"""
|
||||
if session.in_transaction():
|
||||
async with session.begin_nested():
|
||||
@@ -155,6 +168,7 @@ async def lock_tables(
|
||||
SQLAlchemyError: If lock cannot be acquired within timeout
|
||||
|
||||
Example:
|
||||
```python
|
||||
from fastapi_toolsets.db import lock_tables, LockMode
|
||||
|
||||
async with lock_tables(session, [User, Account]):
|
||||
@@ -166,6 +180,7 @@ async def lock_tables(
|
||||
async with lock_tables(session, [Order], mode=LockMode.EXCLUSIVE):
|
||||
# Exclusive lock - no other transactions can access
|
||||
await process_order(session, order_id)
|
||||
```
|
||||
"""
|
||||
table_names = ",".join(table.__tablename__ for table in tables)
|
||||
|
||||
@@ -173,3 +188,150 @@ async def lock_tables(
|
||||
await session.execute(text(f"SET LOCAL lock_timeout='{timeout}'"))
|
||||
await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE"))
|
||||
yield session
|
||||
|
||||
|
||||
async def create_database(
|
||||
db_name: str,
|
||||
*,
|
||||
server_url: str,
|
||||
) -> None:
|
||||
"""Create a database.
|
||||
|
||||
Connects to *server_url* using ``AUTOCOMMIT`` isolation and issues a
|
||||
``CREATE DATABASE`` statement for *db_name*.
|
||||
|
||||
Args:
|
||||
db_name: Name of the database to create.
|
||||
server_url: URL used for server-level DDL (must point to an existing
|
||||
database on the same server).
|
||||
|
||||
Example:
|
||||
```python
|
||||
from fastapi_toolsets.db import create_database
|
||||
|
||||
SERVER_URL = "postgresql+asyncpg://postgres:postgres@localhost/postgres"
|
||||
await create_database("myapp_test", server_url=SERVER_URL)
|
||||
```
|
||||
"""
|
||||
engine = create_async_engine(server_url, isolation_level="AUTOCOMMIT")
|
||||
try:
|
||||
async with engine.connect() as conn:
|
||||
await conn.execute(text(f"CREATE DATABASE {db_name}"))
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
async def cleanup_tables(
|
||||
session: AsyncSession,
|
||||
base: type[DeclarativeBase],
|
||||
) -> None:
|
||||
"""Truncate all tables for fast between-test cleanup.
|
||||
|
||||
Executes a single ``TRUNCATE … RESTART IDENTITY CASCADE`` statement
|
||||
across every table in *base*'s metadata, which is significantly faster
|
||||
than dropping and re-creating tables between tests.
|
||||
|
||||
This is a no-op when the metadata contains no tables.
|
||||
|
||||
Args:
|
||||
session: An active async database session.
|
||||
base: SQLAlchemy DeclarativeBase class containing model metadata.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@pytest.fixture
|
||||
async def db_session(worker_db_url):
|
||||
async with create_db_session(worker_db_url, Base) as session:
|
||||
yield session
|
||||
await cleanup_tables(session, Base)
|
||||
```
|
||||
"""
|
||||
tables = base.metadata.sorted_tables
|
||||
if not tables:
|
||||
return
|
||||
|
||||
table_names = ", ".join(f'"{t.name}"' for t in tables)
|
||||
await session.execute(text(f"TRUNCATE {table_names} RESTART IDENTITY CASCADE"))
|
||||
await session.commit()
|
||||
|
||||
|
||||
_M = TypeVar("_M", bound=DeclarativeBase)
|
||||
|
||||
|
||||
async def wait_for_row_change(
|
||||
session: AsyncSession,
|
||||
model: type[_M],
|
||||
pk_value: Any,
|
||||
*,
|
||||
columns: list[str] | None = None,
|
||||
interval: float = 0.5,
|
||||
timeout: float | None = None,
|
||||
) -> _M:
|
||||
"""Poll a database row until a change is detected.
|
||||
|
||||
Queries the row every ``interval`` seconds and returns the model instance
|
||||
once a change is detected in any column (or only the specified ``columns``).
|
||||
|
||||
Args:
|
||||
session: AsyncSession instance
|
||||
model: SQLAlchemy model class
|
||||
pk_value: Primary key value of the row to watch
|
||||
columns: Optional list of column names to watch. If None, all columns
|
||||
are watched.
|
||||
interval: Polling interval in seconds (default: 0.5)
|
||||
timeout: Maximum time to wait in seconds. None means wait forever.
|
||||
|
||||
Returns:
|
||||
The refreshed model instance with updated values
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the row does not exist or is deleted during polling
|
||||
TimeoutError: If timeout expires before a change is detected
|
||||
|
||||
Example:
|
||||
```python
|
||||
from fastapi_toolsets.db import wait_for_row_change
|
||||
|
||||
# Wait for any column to change
|
||||
updated = await wait_for_row_change(session, User, user_id)
|
||||
|
||||
# Watch specific columns with a timeout
|
||||
updated = await wait_for_row_change(
|
||||
session, User, user_id,
|
||||
columns=["status", "email"],
|
||||
interval=1.0,
|
||||
timeout=30.0,
|
||||
)
|
||||
```
|
||||
"""
|
||||
instance = await session.get(model, pk_value)
|
||||
if instance is None:
|
||||
raise NotFoundError(f"{model.__name__} with pk={pk_value!r} not found")
|
||||
|
||||
if columns is not None:
|
||||
watch_cols = columns
|
||||
else:
|
||||
watch_cols = [attr.key for attr in model.__mapper__.column_attrs]
|
||||
|
||||
initial = {col: getattr(instance, col) for col in watch_cols}
|
||||
|
||||
elapsed = 0.0
|
||||
while True:
|
||||
await asyncio.sleep(interval)
|
||||
elapsed += interval
|
||||
|
||||
if timeout is not None and elapsed >= timeout:
|
||||
raise TimeoutError(
|
||||
f"No change detected on {model.__name__} "
|
||||
f"with pk={pk_value!r} within {timeout}s"
|
||||
)
|
||||
|
||||
session.expunge(instance)
|
||||
instance = await session.get(model, pk_value)
|
||||
|
||||
if instance is None:
|
||||
raise NotFoundError(f"{model.__name__} with pk={pk_value!r} was deleted")
|
||||
|
||||
current = {col: getattr(instance, col) for col in watch_cols}
|
||||
if current != initial:
|
||||
return instance
|
||||
|
||||
155
src/fastapi_toolsets/dependencies.py
Normal file
155
src/fastapi_toolsets/dependencies.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""Dependency factories for FastAPI routes."""
|
||||
|
||||
import inspect
|
||||
import typing
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi.params import Depends as DependsClass
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from .crud import CrudFactory
|
||||
from .types import ModelType, SessionDependency
|
||||
|
||||
__all__ = ["BodyDependency", "PathDependency"]
|
||||
|
||||
|
||||
def _unwrap_session_dep(session_dep: SessionDependency) -> Callable[..., Any]:
|
||||
"""Extract the plain callable from ``Annotated[AsyncSession, Depends(fn)]`` if needed."""
|
||||
if typing.get_origin(session_dep) is typing.Annotated:
|
||||
for arg in typing.get_args(session_dep)[1:]:
|
||||
if isinstance(arg, DependsClass):
|
||||
return arg.dependency
|
||||
return session_dep
|
||||
|
||||
|
||||
def PathDependency(
|
||||
model: type[ModelType],
|
||||
field: Any,
|
||||
*,
|
||||
session_dep: SessionDependency,
|
||||
param_name: str | None = None,
|
||||
) -> ModelType:
|
||||
"""Create a dependency that fetches a DB object from a path parameter.
|
||||
|
||||
Args:
|
||||
model: SQLAlchemy model class
|
||||
field: Model field to filter by (e.g., User.id)
|
||||
session_dep: Session dependency function (e.g., get_db)
|
||||
param_name: Path parameter name (defaults to model_field, e.g., user_id)
|
||||
|
||||
Returns:
|
||||
A Depends() instance that resolves to the model instance
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no matching record is found
|
||||
|
||||
Example:
|
||||
```python
|
||||
UserDep = PathDependency(User, User.id, session_dep=get_db)
|
||||
|
||||
@router.get("/user/{id}")
|
||||
async def get(
|
||||
user: User = UserDep,
|
||||
): ...
|
||||
```
|
||||
"""
|
||||
session_callable = _unwrap_session_dep(session_dep)
|
||||
crud = CrudFactory(model)
|
||||
name = (
|
||||
param_name
|
||||
if param_name is not None
|
||||
else "{}_{}".format(model.__name__.lower(), field.key)
|
||||
)
|
||||
python_type = field.type.python_type
|
||||
|
||||
async def dependency(
|
||||
session: AsyncSession = Depends(session_callable), **kwargs: Any
|
||||
) -> ModelType:
|
||||
value = kwargs[name]
|
||||
return await crud.get(session, filters=[field == value])
|
||||
|
||||
setattr(
|
||||
dependency,
|
||||
"__signature__",
|
||||
inspect.Signature(
|
||||
parameters=[
|
||||
inspect.Parameter(
|
||||
name, inspect.Parameter.KEYWORD_ONLY, annotation=python_type
|
||||
),
|
||||
inspect.Parameter(
|
||||
"session",
|
||||
inspect.Parameter.KEYWORD_ONLY,
|
||||
annotation=AsyncSession,
|
||||
default=Depends(session_callable),
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
return cast(ModelType, Depends(cast(Callable[..., ModelType], dependency)))
|
||||
|
||||
|
||||
def BodyDependency(
|
||||
model: type[ModelType],
|
||||
field: Any,
|
||||
*,
|
||||
session_dep: SessionDependency,
|
||||
body_field: str,
|
||||
) -> ModelType:
|
||||
"""Create a dependency that fetches a DB object from a body field.
|
||||
|
||||
Args:
|
||||
model: SQLAlchemy model class
|
||||
field: Model field to filter by (e.g., User.id)
|
||||
session_dep: Session dependency function (e.g., get_db)
|
||||
body_field: Name of the field in the request body
|
||||
|
||||
Returns:
|
||||
A Depends() instance that resolves to the model instance
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no matching record is found
|
||||
|
||||
Example:
|
||||
```python
|
||||
UserDep = BodyDependency(
|
||||
User, User.ctfd_id, session_dep=get_db, body_field="user_id"
|
||||
)
|
||||
|
||||
@router.post("/assign")
|
||||
async def assign(
|
||||
user: User = UserDep,
|
||||
): ...
|
||||
```
|
||||
"""
|
||||
session_callable = _unwrap_session_dep(session_dep)
|
||||
crud = CrudFactory(model)
|
||||
python_type = field.type.python_type
|
||||
|
||||
async def dependency(
|
||||
session: AsyncSession = Depends(session_callable), **kwargs: Any
|
||||
) -> ModelType:
|
||||
value = kwargs[body_field]
|
||||
return await crud.get(session, filters=[field == value])
|
||||
|
||||
setattr(
|
||||
dependency,
|
||||
"__signature__",
|
||||
inspect.Signature(
|
||||
parameters=[
|
||||
inspect.Parameter(
|
||||
body_field, inspect.Parameter.KEYWORD_ONLY, annotation=python_type
|
||||
),
|
||||
inspect.Parameter(
|
||||
"session",
|
||||
inspect.Parameter.KEYWORD_ONLY,
|
||||
annotation=AsyncSession,
|
||||
default=Depends(session_callable),
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
return cast(ModelType, Depends(cast(Callable[..., ModelType], dependency)))
|
||||
@@ -1,7 +1,13 @@
|
||||
"""Standardized API exceptions and error response handlers."""
|
||||
|
||||
from .exceptions import (
|
||||
ApiError,
|
||||
ApiException,
|
||||
ConflictError,
|
||||
ForbiddenError,
|
||||
InvalidFacetFilterError,
|
||||
InvalidOrderFieldError,
|
||||
NoSearchableFieldsError,
|
||||
NotFoundError,
|
||||
UnauthorizedError,
|
||||
generate_error_responses,
|
||||
@@ -9,11 +15,15 @@ from .exceptions import (
|
||||
from .handler import init_exceptions_handlers
|
||||
|
||||
__all__ = [
|
||||
"init_exceptions_handlers",
|
||||
"generate_error_responses",
|
||||
"ApiError",
|
||||
"ApiException",
|
||||
"ConflictError",
|
||||
"ForbiddenError",
|
||||
"generate_error_responses",
|
||||
"init_exceptions_handlers",
|
||||
"InvalidFacetFilterError",
|
||||
"InvalidOrderFieldError",
|
||||
"NoSearchableFieldsError",
|
||||
"NotFoundError",
|
||||
"UnauthorizedError",
|
||||
]
|
||||
|
||||
@@ -6,30 +6,46 @@ from ..schemas import ApiError, ErrorResponse, ResponseStatus
|
||||
|
||||
|
||||
class ApiException(Exception):
|
||||
"""Base exception for API errors with structured response.
|
||||
|
||||
Subclass this to create custom API exceptions with consistent error format.
|
||||
The exception handler will use api_error to generate the response.
|
||||
|
||||
Example:
|
||||
class CustomError(ApiException):
|
||||
api_error = ApiError(
|
||||
code=400,
|
||||
msg="Bad Request",
|
||||
desc="The request was invalid.",
|
||||
err_code="CUSTOM-400",
|
||||
)
|
||||
"""
|
||||
"""Base exception for API errors with structured response."""
|
||||
|
||||
api_error: ClassVar[ApiError]
|
||||
|
||||
def __init__(self, detail: str | None = None):
|
||||
def __init_subclass__(cls, abstract: bool = False, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
if not abstract and not hasattr(cls, "api_error"):
|
||||
raise TypeError(
|
||||
f"{cls.__name__} must define an 'api_error' class attribute. "
|
||||
"Pass abstract=True when creating intermediate base classes."
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
detail: str | None = None,
|
||||
*,
|
||||
desc: str | None = None,
|
||||
data: Any = None,
|
||||
) -> None:
|
||||
"""Initialize the exception.
|
||||
|
||||
Args:
|
||||
detail: Optional override for the error message
|
||||
detail: Optional human-readable message
|
||||
desc: Optional per-instance override for the ``description`` field
|
||||
in the HTTP response body.
|
||||
data: Optional per-instance override for the ``data`` field in the
|
||||
HTTP response body.
|
||||
"""
|
||||
super().__init__(detail or self.api_error.msg)
|
||||
updates: dict[str, Any] = {}
|
||||
if detail is not None:
|
||||
updates["msg"] = detail
|
||||
if desc is not None:
|
||||
updates["desc"] = desc
|
||||
if data is not None:
|
||||
updates["data"] = data
|
||||
if updates:
|
||||
object.__setattr__(
|
||||
self, "api_error", self.__class__.api_error.model_copy(update=updates)
|
||||
)
|
||||
super().__init__(self.api_error.msg)
|
||||
|
||||
|
||||
class UnauthorizedError(ApiException):
|
||||
@@ -76,90 +92,120 @@ class ConflictError(ApiException):
|
||||
)
|
||||
|
||||
|
||||
class InsufficientRolesError(ForbiddenError):
|
||||
"""User does not have the required roles."""
|
||||
class NoSearchableFieldsError(ApiException):
|
||||
"""Raised when search is requested but no searchable fields are available."""
|
||||
|
||||
api_error = ApiError(
|
||||
code=403,
|
||||
msg="Insufficient Roles",
|
||||
desc="You do not have the required roles to access this resource.",
|
||||
err_code="RBAC-403",
|
||||
code=400,
|
||||
msg="No Searchable Fields",
|
||||
desc="No searchable fields configured for this resource.",
|
||||
err_code="SEARCH-400",
|
||||
)
|
||||
|
||||
def __init__(self, required_roles: list[str], user_roles: set[str] | None = None):
|
||||
self.required_roles = required_roles
|
||||
self.user_roles = user_roles
|
||||
def __init__(self, model: type) -> None:
|
||||
"""Initialize the exception.
|
||||
|
||||
desc = f"Required roles: {', '.join(required_roles)}"
|
||||
if user_roles is not None:
|
||||
desc += f". User has: {', '.join(user_roles) if user_roles else 'no roles'}"
|
||||
|
||||
super().__init__(desc)
|
||||
Args:
|
||||
model: The model class that has no searchable fields configured.
|
||||
"""
|
||||
self.model = model
|
||||
super().__init__(
|
||||
desc=(
|
||||
f"No searchable fields found for model '{model.__name__}'. "
|
||||
"Provide 'search_fields' parameter or set 'searchable_fields' on the CRUD class."
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class UserNotFoundError(NotFoundError):
|
||||
"""User was not found."""
|
||||
class InvalidFacetFilterError(ApiException):
|
||||
"""Raised when filter_by contains a key not declared in facet_fields."""
|
||||
|
||||
api_error = ApiError(
|
||||
code=404,
|
||||
msg="User Not Found",
|
||||
desc="The requested user was not found.",
|
||||
err_code="USER-404",
|
||||
code=400,
|
||||
msg="Invalid Facet Filter",
|
||||
desc="One or more filter_by keys are not declared as facet fields.",
|
||||
err_code="FACET-400",
|
||||
)
|
||||
|
||||
def __init__(self, key: str, valid_keys: set[str]) -> None:
|
||||
"""Initialize the exception.
|
||||
|
||||
class RoleNotFoundError(NotFoundError):
|
||||
"""Role was not found."""
|
||||
Args:
|
||||
key: The unknown filter key provided by the caller.
|
||||
valid_keys: Set of valid keys derived from the declared facet_fields.
|
||||
"""
|
||||
self.key = key
|
||||
self.valid_keys = valid_keys
|
||||
super().__init__(
|
||||
desc=(
|
||||
f"'{key}' is not a declared facet field. "
|
||||
f"Valid keys: {sorted(valid_keys) or 'none — set facet_fields on the CRUD class'}."
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class InvalidOrderFieldError(ApiException):
|
||||
"""Raised when order_by contains a field not in the allowed order fields."""
|
||||
|
||||
api_error = ApiError(
|
||||
code=404,
|
||||
msg="Role Not Found",
|
||||
desc="The requested role was not found.",
|
||||
err_code="ROLE-404",
|
||||
code=422,
|
||||
msg="Invalid Order Field",
|
||||
desc="The requested order field is not allowed for this resource.",
|
||||
err_code="SORT-422",
|
||||
)
|
||||
|
||||
def __init__(self, field: str, valid_fields: list[str]) -> None:
|
||||
"""Initialize the exception.
|
||||
|
||||
Args:
|
||||
field: The unknown order field provided by the caller.
|
||||
valid_fields: List of valid field names.
|
||||
"""
|
||||
self.field = field
|
||||
self.valid_fields = valid_fields
|
||||
super().__init__(
|
||||
desc=f"'{field}' is not an allowed order field. Valid fields: {valid_fields}."
|
||||
)
|
||||
|
||||
|
||||
def generate_error_responses(
|
||||
*errors: type[ApiException],
|
||||
) -> dict[int | str, dict[str, Any]]:
|
||||
"""Generate OpenAPI response documentation for exceptions.
|
||||
|
||||
Use this to document possible error responses for an endpoint.
|
||||
|
||||
Args:
|
||||
*errors: Exception classes that inherit from ApiException
|
||||
*errors: Exception classes that inherit from ApiException.
|
||||
|
||||
Returns:
|
||||
Dict suitable for FastAPI's responses parameter
|
||||
|
||||
Example:
|
||||
from fastapi_toolsets.exceptions import generate_error_responses, UnauthorizedError, ForbiddenError
|
||||
|
||||
@app.get(
|
||||
"/admin",
|
||||
responses=generate_error_responses(UnauthorizedError, ForbiddenError)
|
||||
)
|
||||
async def admin_endpoint():
|
||||
...
|
||||
Dict suitable for FastAPI's ``responses`` parameter.
|
||||
"""
|
||||
responses: dict[int | str, dict[str, Any]] = {}
|
||||
|
||||
for error in errors:
|
||||
api_error = error.api_error
|
||||
code = api_error.code
|
||||
|
||||
responses[api_error.code] = {
|
||||
"model": ErrorResponse,
|
||||
"description": api_error.msg,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"example": {
|
||||
"data": None,
|
||||
"status": ResponseStatus.FAIL.value,
|
||||
"message": api_error.msg,
|
||||
"description": api_error.desc,
|
||||
"error_code": api_error.err_code,
|
||||
if code not in responses:
|
||||
responses[code] = {
|
||||
"model": ErrorResponse,
|
||||
"description": api_error.msg,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"examples": {},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
responses[code]["content"]["application/json"]["examples"][
|
||||
api_error.err_code
|
||||
] = {
|
||||
"summary": api_error.msg,
|
||||
"value": {
|
||||
"data": api_error.data,
|
||||
"status": ResponseStatus.FAIL.value,
|
||||
"message": api_error.msg,
|
||||
"description": api_error.desc,
|
||||
"error_code": api_error.err_code,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -1,50 +1,69 @@
|
||||
"""Exception handlers for FastAPI applications."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, Request, Response, status
|
||||
from fastapi.exceptions import RequestValidationError, ResponseValidationError
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.exceptions import (
|
||||
HTTPException,
|
||||
RequestValidationError,
|
||||
ResponseValidationError,
|
||||
)
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from ..schemas import ResponseStatus
|
||||
from ..schemas import ErrorResponse, ResponseStatus
|
||||
from .exceptions import ApiException
|
||||
|
||||
_VALIDATION_LOCATION_PARAMS: frozenset[str] = frozenset(
|
||||
{"body", "query", "path", "header", "cookie"}
|
||||
)
|
||||
|
||||
|
||||
def init_exceptions_handlers(app: FastAPI) -> FastAPI:
|
||||
"""Register exception handlers and custom OpenAPI schema on a FastAPI app.
|
||||
|
||||
Args:
|
||||
app: FastAPI application instance.
|
||||
|
||||
Returns:
|
||||
The same FastAPI instance (for chaining).
|
||||
"""
|
||||
_register_exception_handlers(app)
|
||||
app.openapi = lambda: _custom_openapi(app) # type: ignore[method-assign]
|
||||
_original_openapi = app.openapi
|
||||
app.openapi = lambda: _patched_openapi(app, _original_openapi) # type: ignore[method-assign]
|
||||
return app
|
||||
|
||||
|
||||
def _register_exception_handlers(app: FastAPI) -> None:
|
||||
"""Register all exception handlers on a FastAPI application.
|
||||
|
||||
Args:
|
||||
app: FastAPI application instance
|
||||
|
||||
Example:
|
||||
from fastapi import FastAPI
|
||||
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app)
|
||||
"""
|
||||
"""Register all exception handlers on a FastAPI application."""
|
||||
|
||||
@app.exception_handler(ApiException)
|
||||
async def api_exception_handler(request: Request, exc: ApiException) -> Response:
|
||||
"""Handle custom API exceptions with structured response."""
|
||||
api_error = exc.api_error
|
||||
|
||||
error_response = ErrorResponse(
|
||||
data=api_error.data,
|
||||
message=api_error.msg,
|
||||
description=api_error.desc,
|
||||
error_code=api_error.err_code,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=api_error.code,
|
||||
content={
|
||||
"data": None,
|
||||
"status": ResponseStatus.FAIL.value,
|
||||
"message": api_error.msg,
|
||||
"description": api_error.desc,
|
||||
"error_code": api_error.err_code,
|
||||
},
|
||||
content=error_response.model_dump(),
|
||||
)
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
|
||||
"""Handle Starlette/FastAPI HTTPException with a consistent error format."""
|
||||
detail = exc.detail if isinstance(exc.detail, str) else "HTTP Error"
|
||||
error_response = ErrorResponse(
|
||||
message=detail,
|
||||
error_code=f"HTTP-{exc.status_code}",
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=error_response.model_dump(),
|
||||
headers=getattr(exc, "headers", None),
|
||||
)
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
@@ -64,15 +83,14 @@ def _register_exception_handlers(app: FastAPI) -> None:
|
||||
@app.exception_handler(Exception)
|
||||
async def generic_exception_handler(request: Request, exc: Exception) -> Response:
|
||||
"""Handle all unhandled exceptions with a generic 500 response."""
|
||||
error_response = ErrorResponse(
|
||||
message="Internal Server Error",
|
||||
description="An unexpected error occurred. Please try again later.",
|
||||
error_code="SERVER-500",
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={
|
||||
"data": None,
|
||||
"status": ResponseStatus.FAIL.value,
|
||||
"message": "Internal Server Error",
|
||||
"description": "An unexpected error occurred. Please try again later.",
|
||||
"error_code": "SERVER-500",
|
||||
},
|
||||
content=error_response.model_dump(),
|
||||
)
|
||||
|
||||
|
||||
@@ -84,11 +102,10 @@ def _format_validation_error(
|
||||
formatted_errors = []
|
||||
|
||||
for error in errors:
|
||||
field_path = ".".join(
|
||||
str(loc)
|
||||
for loc in error["loc"]
|
||||
if loc not in ("body", "query", "path", "header", "cookie")
|
||||
)
|
||||
locs = error["loc"]
|
||||
if locs and locs[0] in _VALIDATION_LOCATION_PARAMS:
|
||||
locs = locs[1:]
|
||||
field_path = ".".join(str(loc) for loc in locs)
|
||||
formatted_errors.append(
|
||||
{
|
||||
"field": field_path or "root",
|
||||
@@ -97,46 +114,35 @@ def _format_validation_error(
|
||||
}
|
||||
)
|
||||
|
||||
error_response = ErrorResponse(
|
||||
data={"errors": formatted_errors},
|
||||
message="Validation Error",
|
||||
description=f"{len(formatted_errors)} validation error(s) detected",
|
||||
error_code="VAL-422",
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
content={
|
||||
"data": {"errors": formatted_errors},
|
||||
"status": ResponseStatus.FAIL.value,
|
||||
"message": "Validation Error",
|
||||
"description": f"{len(formatted_errors)} validation error(s) detected",
|
||||
"error_code": "VAL-422",
|
||||
},
|
||||
content=error_response.model_dump(),
|
||||
)
|
||||
|
||||
|
||||
def _custom_openapi(app: FastAPI) -> dict[str, Any]:
|
||||
"""Generate custom OpenAPI schema with standardized error format.
|
||||
|
||||
Replaces default 422 validation error responses with the custom format.
|
||||
def _patched_openapi(
|
||||
app: FastAPI, original_openapi: Callable[[], dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
"""Generate the OpenAPI schema and replace default 422 responses.
|
||||
|
||||
Args:
|
||||
app: FastAPI application instance
|
||||
app: FastAPI application instance.
|
||||
original_openapi: The previous ``app.openapi`` callable to delegate to.
|
||||
|
||||
Returns:
|
||||
OpenAPI schema dict
|
||||
|
||||
Example:
|
||||
from fastapi import FastAPI
|
||||
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app) # Automatically sets custom OpenAPI
|
||||
Patched OpenAPI schema dict.
|
||||
"""
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
openapi_version=app.openapi_version,
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
openapi_schema = original_openapi()
|
||||
|
||||
for path_data in openapi_schema.get("paths", {}).values():
|
||||
for operation in path_data.values():
|
||||
@@ -146,20 +152,25 @@ def _custom_openapi(app: FastAPI) -> dict[str, Any]:
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"example": {
|
||||
"data": {
|
||||
"errors": [
|
||||
{
|
||||
"field": "field_name",
|
||||
"message": "value is not valid",
|
||||
"type": "value_error",
|
||||
}
|
||||
]
|
||||
},
|
||||
"status": ResponseStatus.FAIL.value,
|
||||
"message": "Validation Error",
|
||||
"description": "1 validation error(s) detected",
|
||||
"error_code": "VAL-422",
|
||||
"examples": {
|
||||
"VAL-422": {
|
||||
"summary": "Validation Error",
|
||||
"value": {
|
||||
"data": {
|
||||
"errors": [
|
||||
{
|
||||
"field": "field_name",
|
||||
"message": "value is not valid",
|
||||
"type": "value_error",
|
||||
}
|
||||
]
|
||||
},
|
||||
"status": ResponseStatus.FAIL.value,
|
||||
"message": "Validation Error",
|
||||
"description": "1 validation error(s) detected",
|
||||
"error_code": "VAL-422",
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Fixture system for seeding databases with dependency resolution."""
|
||||
|
||||
from .enum import LoadStrategy
|
||||
from .registry import Context, FixtureRegistry
|
||||
from .utils import get_obj_by_attr, load_fixtures, load_fixtures_by_context
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Enums for fixture loading strategies and contexts."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
"""Fixture system with dependency management and context support."""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from ..logger import get_logger
|
||||
from .enum import Context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -26,6 +26,7 @@ class FixtureRegistry:
|
||||
"""Registry for managing fixtures with dependencies.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from fastapi_toolsets.fixtures import FixtureRegistry, Context
|
||||
|
||||
fixtures = FixtureRegistry()
|
||||
@@ -48,10 +49,19 @@ class FixtureRegistry:
|
||||
return [
|
||||
Post(id=1, title="Test", user_id=1),
|
||||
]
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
contexts: list[str | Context] | None = None,
|
||||
) -> None:
|
||||
self._fixtures: dict[str, Fixture] = {}
|
||||
self._default_contexts: list[str] | None = (
|
||||
[c.value if isinstance(c, Context) else c for c in contexts]
|
||||
if contexts
|
||||
else None
|
||||
)
|
||||
|
||||
def register(
|
||||
self,
|
||||
@@ -72,6 +82,7 @@ class FixtureRegistry:
|
||||
contexts: List of contexts this fixture belongs to
|
||||
|
||||
Example:
|
||||
```python
|
||||
@fixtures.register
|
||||
def roles():
|
||||
return [Role(id=1, name="admin")]
|
||||
@@ -79,16 +90,21 @@ class FixtureRegistry:
|
||||
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
|
||||
def test_users():
|
||||
return [User(id=1, username="test", role_id=1)]
|
||||
```
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
fn: Callable[[], Sequence[DeclarativeBase]],
|
||||
) -> Callable[[], Sequence[DeclarativeBase]]:
|
||||
fixture_name = name or cast(Any, fn).__name__
|
||||
fixture_contexts = [
|
||||
c.value if isinstance(c, Context) else c
|
||||
for c in (contexts or [Context.BASE])
|
||||
]
|
||||
if contexts is not None:
|
||||
fixture_contexts = [
|
||||
c.value if isinstance(c, Context) else c for c in contexts
|
||||
]
|
||||
elif self._default_contexts is not None:
|
||||
fixture_contexts = self._default_contexts
|
||||
else:
|
||||
fixture_contexts = [Context.BASE.value]
|
||||
|
||||
self._fixtures[fixture_name] = Fixture(
|
||||
name=fixture_name,
|
||||
@@ -102,6 +118,34 @@ class FixtureRegistry:
|
||||
return decorator(func)
|
||||
return decorator
|
||||
|
||||
def include_registry(self, registry: "FixtureRegistry") -> None:
|
||||
"""Include another `FixtureRegistry` in the same current `FixtureRegistry`.
|
||||
|
||||
Args:
|
||||
registry: The `FixtureRegistry` to include
|
||||
|
||||
Raises:
|
||||
ValueError: If a fixture name already exists in the current registry
|
||||
|
||||
Example:
|
||||
```python
|
||||
registry = FixtureRegistry()
|
||||
dev_registry = FixtureRegistry()
|
||||
|
||||
@dev_registry.register
|
||||
def dev_data():
|
||||
return [...]
|
||||
|
||||
registry.include_registry(registry=dev_registry)
|
||||
```
|
||||
"""
|
||||
for name, fixture in registry._fixtures.items():
|
||||
if name in self._fixtures:
|
||||
raise ValueError(
|
||||
f"Fixture '{name}' already exists in the current registry"
|
||||
)
|
||||
self._fixtures[name] = fixture
|
||||
|
||||
def get(self, name: str) -> Fixture:
|
||||
"""Get a fixture by name."""
|
||||
if name not in self._fixtures:
|
||||
|
||||
@@ -1,91 +1,18 @@
|
||||
import logging
|
||||
"""Fixture loading utilities for database seeding."""
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, TypeVar
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from ..db import get_transaction
|
||||
from ..logger import get_logger
|
||||
from ..types import ModelType
|
||||
from .enum import LoadStrategy
|
||||
from .registry import Context, FixtureRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T", bound=DeclarativeBase)
|
||||
|
||||
|
||||
def get_obj_by_attr(
|
||||
fixtures: Callable[[], Sequence[T]], attr_name: str, value: Any
|
||||
) -> T:
|
||||
"""Get a SQLAlchemy model instance by matching an attribute value.
|
||||
|
||||
Args:
|
||||
fixtures: A fixture function registered via ``@registry.register``
|
||||
that returns a sequence of SQLAlchemy model instances.
|
||||
attr_name: Name of the attribute to match against.
|
||||
value: Value to match.
|
||||
|
||||
Returns:
|
||||
The first model instance where the attribute matches the given value.
|
||||
|
||||
Raises:
|
||||
StopIteration: If no matching object is found.
|
||||
"""
|
||||
return next(obj for obj in fixtures() if getattr(obj, attr_name) == value)
|
||||
|
||||
|
||||
async def load_fixtures(
|
||||
session: AsyncSession,
|
||||
registry: FixtureRegistry,
|
||||
*names: str,
|
||||
strategy: LoadStrategy = LoadStrategy.MERGE,
|
||||
) -> dict[str, list[DeclarativeBase]]:
|
||||
"""Load specific fixtures by name with dependencies.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
registry: Fixture registry
|
||||
*names: Fixture names to load (dependencies auto-resolved)
|
||||
strategy: How to handle existing records
|
||||
|
||||
Returns:
|
||||
Dict mapping fixture names to loaded instances
|
||||
|
||||
Example:
|
||||
# Loads 'roles' first (dependency), then 'users'
|
||||
result = await load_fixtures(session, fixtures, "users")
|
||||
print(result["users"]) # [User(...), ...]
|
||||
"""
|
||||
ordered = registry.resolve_dependencies(*names)
|
||||
return await _load_ordered(session, registry, ordered, strategy)
|
||||
|
||||
|
||||
async def load_fixtures_by_context(
|
||||
session: AsyncSession,
|
||||
registry: FixtureRegistry,
|
||||
*contexts: str | Context,
|
||||
strategy: LoadStrategy = LoadStrategy.MERGE,
|
||||
) -> dict[str, list[DeclarativeBase]]:
|
||||
"""Load all fixtures for specific contexts.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
registry: Fixture registry
|
||||
*contexts: Contexts to load (e.g., Context.BASE, Context.TESTING)
|
||||
strategy: How to handle existing records
|
||||
|
||||
Returns:
|
||||
Dict mapping fixture names to loaded instances
|
||||
|
||||
Example:
|
||||
# Load base + testing fixtures
|
||||
await load_fixtures_by_context(
|
||||
session, fixtures,
|
||||
Context.BASE, Context.TESTING
|
||||
)
|
||||
"""
|
||||
ordered = registry.resolve_context_dependencies(*contexts)
|
||||
return await _load_ordered(session, registry, ordered, strategy)
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
async def _load_ordered(
|
||||
@@ -118,7 +45,7 @@ async def _load_ordered(
|
||||
merged = await session.merge(instance)
|
||||
loaded.append(merged)
|
||||
|
||||
elif strategy == LoadStrategy.SKIP_EXISTING:
|
||||
else: # LoadStrategy.SKIP_EXISTING
|
||||
pk = _get_primary_key(instance)
|
||||
if pk is not None:
|
||||
existing = await session.get(type(instance), pk)
|
||||
@@ -147,3 +74,70 @@ def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
||||
if all(v is not None for v in pk_values):
|
||||
return pk_values
|
||||
return None
|
||||
|
||||
|
||||
def get_obj_by_attr(
|
||||
fixtures: Callable[[], Sequence[ModelType]], attr_name: str, value: Any
|
||||
) -> ModelType:
|
||||
"""Get a SQLAlchemy model instance by matching an attribute value.
|
||||
|
||||
Args:
|
||||
fixtures: A fixture function registered via ``@registry.register``
|
||||
that returns a sequence of SQLAlchemy model instances.
|
||||
attr_name: Name of the attribute to match against.
|
||||
value: Value to match.
|
||||
|
||||
Returns:
|
||||
The first model instance where the attribute matches the given value.
|
||||
|
||||
Raises:
|
||||
StopIteration: If no matching object is found in the fixture group.
|
||||
"""
|
||||
try:
|
||||
return next(obj for obj in fixtures() if getattr(obj, attr_name) == value)
|
||||
except StopIteration:
|
||||
raise StopIteration(
|
||||
f"No object with {attr_name}={value} found in fixture '{getattr(fixtures, '__name__', repr(fixtures))}'"
|
||||
) from None
|
||||
|
||||
|
||||
async def load_fixtures(
|
||||
session: AsyncSession,
|
||||
registry: FixtureRegistry,
|
||||
*names: str,
|
||||
strategy: LoadStrategy = LoadStrategy.MERGE,
|
||||
) -> dict[str, list[DeclarativeBase]]:
|
||||
"""Load specific fixtures by name with dependencies.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
registry: Fixture registry
|
||||
*names: Fixture names to load (dependencies auto-resolved)
|
||||
strategy: How to handle existing records
|
||||
|
||||
Returns:
|
||||
Dict mapping fixture names to loaded instances
|
||||
"""
|
||||
ordered = registry.resolve_dependencies(*names)
|
||||
return await _load_ordered(session, registry, ordered, strategy)
|
||||
|
||||
|
||||
async def load_fixtures_by_context(
|
||||
session: AsyncSession,
|
||||
registry: FixtureRegistry,
|
||||
*contexts: str | Context,
|
||||
strategy: LoadStrategy = LoadStrategy.MERGE,
|
||||
) -> dict[str, list[DeclarativeBase]]:
|
||||
"""Load all fixtures for specific contexts.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
registry: Fixture registry
|
||||
*contexts: Contexts to load (e.g., Context.BASE, Context.TESTING)
|
||||
strategy: How to handle existing records
|
||||
|
||||
Returns:
|
||||
Dict mapping fixture names to loaded instances
|
||||
"""
|
||||
ordered = registry.resolve_context_dependencies(*contexts)
|
||||
return await _load_ordered(session, registry, ordered, strategy)
|
||||
|
||||
98
src/fastapi_toolsets/logger.py
Normal file
98
src/fastapi_toolsets/logger.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Logging configuration for FastAPI applications and CLI tools."""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import Literal
|
||||
|
||||
__all__ = ["LogLevel", "configure_logging", "get_logger"]
|
||||
|
||||
DEFAULT_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
UVICORN_LOGGERS = ("uvicorn", "uvicorn.access", "uvicorn.error")
|
||||
|
||||
LogLevel = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
|
||||
|
||||
|
||||
def configure_logging(
|
||||
level: LogLevel | int = "INFO",
|
||||
fmt: str = DEFAULT_FORMAT,
|
||||
logger_name: str | None = None,
|
||||
) -> logging.Logger:
|
||||
"""Configure logging with a stdout handler and consistent format.
|
||||
|
||||
Sets up a :class:`~logging.StreamHandler` writing to stdout with the
|
||||
given format and level. Also configures the uvicorn loggers so that
|
||||
FastAPI access logs use the same format.
|
||||
|
||||
Calling this function multiple times is safe -- existing handlers are
|
||||
replaced rather than duplicated.
|
||||
|
||||
Args:
|
||||
level: Log level (e.g. ``"DEBUG"``, ``"INFO"``, or ``logging.DEBUG``).
|
||||
fmt: Log format string. Defaults to
|
||||
``"%(asctime)s - %(name)s - %(levelname)s - %(message)s"``.
|
||||
logger_name: Logger name to configure. ``None`` (the default)
|
||||
configures the root logger so all loggers inherit the settings.
|
||||
|
||||
Returns:
|
||||
The configured Logger instance.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from fastapi_toolsets.logger import configure_logging
|
||||
|
||||
logger = configure_logging("DEBUG")
|
||||
logger.info("Application started")
|
||||
```
|
||||
"""
|
||||
formatter = logging.Formatter(fmt)
|
||||
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
logger = logging.getLogger(logger_name)
|
||||
logger.handlers.clear()
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(level)
|
||||
|
||||
for name in UVICORN_LOGGERS:
|
||||
uv_logger = logging.getLogger(name)
|
||||
uv_logger.handlers.clear()
|
||||
uv_logger.addHandler(handler)
|
||||
uv_logger.setLevel(level)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
_SENTINEL = object()
|
||||
|
||||
|
||||
def get_logger(name: str | None = _SENTINEL) -> logging.Logger: # type: ignore[assignment]
|
||||
"""Return a logger with the given *name*.
|
||||
|
||||
A thin convenience wrapper around :func:`logging.getLogger` that keeps
|
||||
logging imports consistent across the codebase.
|
||||
|
||||
When called without arguments, the caller's ``__name__`` is used
|
||||
automatically, so ``get_logger()`` in a module is equivalent to
|
||||
``logging.getLogger(__name__)``. Pass ``None`` explicitly to get the
|
||||
root logger.
|
||||
|
||||
Args:
|
||||
name: Logger name. Defaults to the caller's ``__name__``.
|
||||
Pass ``None`` to get the root logger.
|
||||
|
||||
Returns:
|
||||
A Logger instance.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from fastapi_toolsets.logger import get_logger
|
||||
|
||||
logger = get_logger() # uses caller's __name__
|
||||
logger = get_logger("myapp") # explicit name
|
||||
logger = get_logger(None) # root logger
|
||||
```
|
||||
"""
|
||||
if name is _SENTINEL:
|
||||
name = sys._getframe(1).f_globals.get("__name__")
|
||||
return logging.getLogger(name)
|
||||
21
src/fastapi_toolsets/metrics/__init__.py
Normal file
21
src/fastapi_toolsets/metrics/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Prometheus metrics integration for FastAPI applications."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .registry import Metric, MetricsRegistry
|
||||
|
||||
try:
|
||||
from .handler import init_metrics
|
||||
except ImportError:
|
||||
|
||||
def init_metrics(*_args: Any, **_kwargs: Any) -> None:
|
||||
from .._imports import require_extra
|
||||
|
||||
require_extra(package="prometheus_client", extra="metrics")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Metric",
|
||||
"MetricsRegistry",
|
||||
"init_metrics",
|
||||
]
|
||||
81
src/fastapi_toolsets/metrics/handler.py
Normal file
81
src/fastapi_toolsets/metrics/handler.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Prometheus metrics endpoint for FastAPI applications."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import Response
|
||||
from prometheus_client import (
|
||||
CONTENT_TYPE_LATEST,
|
||||
CollectorRegistry,
|
||||
generate_latest,
|
||||
multiprocess,
|
||||
)
|
||||
|
||||
from ..logger import get_logger
|
||||
from .registry import MetricsRegistry
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def _is_multiprocess() -> bool:
|
||||
"""Check if prometheus multi-process mode is enabled."""
|
||||
return "PROMETHEUS_MULTIPROC_DIR" in os.environ
|
||||
|
||||
|
||||
def init_metrics(
|
||||
app: FastAPI,
|
||||
registry: MetricsRegistry,
|
||||
*,
|
||||
path: str = "/metrics",
|
||||
) -> FastAPI:
|
||||
"""Register a Prometheus ``/metrics`` endpoint on a FastAPI app.
|
||||
|
||||
Args:
|
||||
app: FastAPI application instance.
|
||||
registry: A :class:`MetricsRegistry` containing providers and collectors.
|
||||
path: URL path for the metrics endpoint (default ``/metrics``).
|
||||
|
||||
Returns:
|
||||
The same FastAPI instance (for chaining).
|
||||
|
||||
Example:
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
from fastapi_toolsets.metrics import MetricsRegistry, init_metrics
|
||||
|
||||
metrics = MetricsRegistry()
|
||||
app = FastAPI()
|
||||
init_metrics(app, registry=metrics)
|
||||
```
|
||||
"""
|
||||
for provider in registry.get_providers():
|
||||
logger.debug("Initialising metric provider '%s'", provider.name)
|
||||
registry._instances[provider.name] = provider.func()
|
||||
|
||||
# Partition collectors and cache env check at startup — both are stable for the app lifetime.
|
||||
async_collectors = [
|
||||
c for c in registry.get_collectors() if asyncio.iscoroutinefunction(c.func)
|
||||
]
|
||||
sync_collectors = [
|
||||
c for c in registry.get_collectors() if not asyncio.iscoroutinefunction(c.func)
|
||||
]
|
||||
multiprocess_mode = _is_multiprocess()
|
||||
|
||||
@app.get(path, include_in_schema=False)
|
||||
async def metrics_endpoint() -> Response:
|
||||
for collector in sync_collectors:
|
||||
collector.func()
|
||||
for collector in async_collectors:
|
||||
await collector.func()
|
||||
|
||||
if multiprocess_mode:
|
||||
prom_registry = CollectorRegistry()
|
||||
multiprocess.MultiProcessCollector(prom_registry)
|
||||
output = generate_latest(prom_registry)
|
||||
else:
|
||||
output = generate_latest()
|
||||
|
||||
return Response(content=output, media_type=CONTENT_TYPE_LATEST)
|
||||
|
||||
return app
|
||||
104
src/fastapi_toolsets/metrics/registry.py
Normal file
104
src/fastapi_toolsets/metrics/registry.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Metrics registry with decorator-based registration."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, cast
|
||||
|
||||
from ..logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@dataclass
|
||||
class Metric:
|
||||
"""A metric definition with metadata."""
|
||||
|
||||
name: str
|
||||
func: Callable[..., Any]
|
||||
collect: bool = field(default=False)
|
||||
|
||||
|
||||
class MetricsRegistry:
|
||||
"""Registry for managing Prometheus metric providers and collectors."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._metrics: dict[str, Metric] = {}
|
||||
self._instances: dict[str, Any] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
func: Callable[..., Any] | None = None,
|
||||
*,
|
||||
name: str | None = None,
|
||||
collect: bool = False,
|
||||
) -> Callable[..., Any]:
|
||||
"""Register a metric provider or collector function.
|
||||
|
||||
Can be used as a decorator with or without arguments.
|
||||
|
||||
Args:
|
||||
func: The metric function to register.
|
||||
name: Metric name (defaults to function name).
|
||||
collect: If ``True``, the function is called on every scrape.
|
||||
If ``False`` (default), called once at init time.
|
||||
"""
|
||||
|
||||
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||
metric_name = name or cast(Any, fn).__name__
|
||||
self._metrics[metric_name] = Metric(
|
||||
name=metric_name,
|
||||
func=fn,
|
||||
collect=collect,
|
||||
)
|
||||
return fn
|
||||
|
||||
if func is not None:
|
||||
return decorator(func)
|
||||
return decorator
|
||||
|
||||
def get(self, name: str) -> Any:
|
||||
"""Return the metric instance created by a provider.
|
||||
|
||||
Args:
|
||||
name: The metric name (defaults to the provider function name).
|
||||
|
||||
Raises:
|
||||
KeyError: If the metric name is unknown or ``init_metrics`` has not
|
||||
been called yet.
|
||||
"""
|
||||
if name not in self._instances:
|
||||
if name in self._metrics:
|
||||
raise KeyError(
|
||||
f"Metric '{name}' exists but has not been initialized yet. "
|
||||
"Ensure init_metrics() has been called before accessing metric instances."
|
||||
)
|
||||
raise KeyError(f"Unknown metric '{name}'.")
|
||||
return self._instances[name]
|
||||
|
||||
def include_registry(self, registry: "MetricsRegistry") -> None:
|
||||
"""Include another :class:`MetricsRegistry` into this one.
|
||||
|
||||
Args:
|
||||
registry: The registry to merge in.
|
||||
|
||||
Raises:
|
||||
ValueError: If a metric name already exists in the current registry.
|
||||
"""
|
||||
for metric_name, definition in registry._metrics.items():
|
||||
if metric_name in self._metrics:
|
||||
raise ValueError(
|
||||
f"Metric '{metric_name}' already exists in the current registry"
|
||||
)
|
||||
self._metrics[metric_name] = definition
|
||||
|
||||
def get_all(self) -> list[Metric]:
|
||||
"""Get all registered metric definitions."""
|
||||
return list(self._metrics.values())
|
||||
|
||||
def get_providers(self) -> list[Metric]:
|
||||
"""Get metric providers (called once at init)."""
|
||||
return [m for m in self._metrics.values() if not m.collect]
|
||||
|
||||
def get_collectors(self) -> list[Metric]:
|
||||
"""Get collectors (called on each scrape)."""
|
||||
return [m for m in self._metrics.values() if m.collect]
|
||||
21
src/fastapi_toolsets/models/__init__.py
Normal file
21
src/fastapi_toolsets/models/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""SQLAlchemy model mixins for common column patterns."""
|
||||
|
||||
from .columns import (
|
||||
CreatedAtMixin,
|
||||
TimestampMixin,
|
||||
UUIDMixin,
|
||||
UUIDv7Mixin,
|
||||
UpdatedAtMixin,
|
||||
)
|
||||
from .watched import ModelEvent, WatchedFieldsMixin, watch
|
||||
|
||||
__all__ = [
|
||||
"ModelEvent",
|
||||
"UUIDMixin",
|
||||
"UUIDv7Mixin",
|
||||
"CreatedAtMixin",
|
||||
"UpdatedAtMixin",
|
||||
"TimestampMixin",
|
||||
"WatchedFieldsMixin",
|
||||
"watch",
|
||||
]
|
||||
58
src/fastapi_toolsets/models/columns.py
Normal file
58
src/fastapi_toolsets/models/columns.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""SQLAlchemy column mixins for common column patterns."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, Uuid, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
__all__ = [
|
||||
"UUIDMixin",
|
||||
"UUIDv7Mixin",
|
||||
"CreatedAtMixin",
|
||||
"UpdatedAtMixin",
|
||||
"TimestampMixin",
|
||||
]
|
||||
|
||||
|
||||
class UUIDMixin:
|
||||
"""Mixin that adds a UUID primary key auto-generated by the database."""
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid,
|
||||
primary_key=True,
|
||||
server_default=text("gen_random_uuid()"),
|
||||
)
|
||||
|
||||
|
||||
class UUIDv7Mixin:
|
||||
"""Mixin that adds a UUIDv7 primary key auto-generated by the database."""
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid,
|
||||
primary_key=True,
|
||||
server_default=text("uuidv7()"),
|
||||
)
|
||||
|
||||
|
||||
class CreatedAtMixin:
|
||||
"""Mixin that adds a ``created_at`` timestamp column."""
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=text("clock_timestamp()"),
|
||||
)
|
||||
|
||||
|
||||
class UpdatedAtMixin:
|
||||
"""Mixin that adds an ``updated_at`` timestamp column."""
|
||||
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=text("clock_timestamp()"),
|
||||
onupdate=text("clock_timestamp()"),
|
||||
)
|
||||
|
||||
|
||||
class TimestampMixin(CreatedAtMixin, UpdatedAtMixin):
|
||||
"""Mixin that combines ``created_at`` and ``updated_at`` timestamp columns."""
|
||||
269
src/fastapi_toolsets/models/watched.py
Normal file
269
src/fastapi_toolsets/models/watched.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""Field-change monitoring via SQLAlchemy session events."""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import weakref
|
||||
from collections.abc import Awaitable
|
||||
from enum import Enum
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import set_committed_value as _sa_set_committed_value
|
||||
|
||||
from ..logger import get_logger
|
||||
|
||||
__all__ = ["ModelEvent", "WatchedFieldsMixin", "watch"]
|
||||
|
||||
_logger = get_logger()
|
||||
_T = TypeVar("_T")
|
||||
_CALLBACK_ERROR_MSG = "WatchedFieldsMixin callback raised an unhandled exception"
|
||||
_WATCHED_FIELDS: weakref.WeakKeyDictionary[type, list[str]] = (
|
||||
weakref.WeakKeyDictionary()
|
||||
)
|
||||
_SESSION_PENDING_NEW = "_ft_pending_new"
|
||||
_SESSION_CREATES = "_ft_creates"
|
||||
_SESSION_DELETES = "_ft_deletes"
|
||||
_SESSION_UPDATES = "_ft_updates"
|
||||
_SESSION_SAVEPOINT_DEPTH = "_ft_sp_depth"
|
||||
|
||||
|
||||
class ModelEvent(str, Enum):
|
||||
"""Event types emitted by :class:`WatchedFieldsMixin`."""
|
||||
|
||||
CREATE = "create"
|
||||
DELETE = "delete"
|
||||
UPDATE = "update"
|
||||
|
||||
|
||||
def watch(*fields: str) -> Any:
|
||||
"""Class decorator to filter which fields trigger ``on_update``.
|
||||
|
||||
Args:
|
||||
*fields: One or more field names to watch. At least one name is required.
|
||||
|
||||
Raises:
|
||||
ValueError: If called with no field names.
|
||||
"""
|
||||
if not fields:
|
||||
raise ValueError("@watch requires at least one field name.")
|
||||
|
||||
def decorator(cls: type[_T]) -> type[_T]:
|
||||
_WATCHED_FIELDS[cls] = list(fields)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
|
||||
"""Read currently-loaded column values into a plain dict."""
|
||||
state = sa_inspect(obj) # InstanceState
|
||||
state_dict = state.dict
|
||||
return {
|
||||
prop.key: state_dict[prop.key]
|
||||
for prop in state.mapper.column_attrs
|
||||
if prop.key in state_dict
|
||||
}
|
||||
|
||||
|
||||
def _get_watched_fields(cls: type) -> list[str] | None:
|
||||
"""Return the watched fields for *cls*, walking the MRO to inherit from parents."""
|
||||
for klass in cls.__mro__:
|
||||
if klass in _WATCHED_FIELDS:
|
||||
return _WATCHED_FIELDS[klass]
|
||||
return None
|
||||
|
||||
|
||||
def _upsert_changes(
|
||||
pending: dict[int, tuple[Any, dict[str, dict[str, Any]]]],
|
||||
obj: Any,
|
||||
changes: dict[str, dict[str, Any]],
|
||||
) -> None:
|
||||
"""Insert or merge *changes* into *pending* for *obj*."""
|
||||
key = id(obj)
|
||||
if key in pending:
|
||||
existing = pending[key][1]
|
||||
for field, change in changes.items():
|
||||
if field in existing:
|
||||
existing[field]["new"] = change["new"]
|
||||
else:
|
||||
existing[field] = change
|
||||
else:
|
||||
pending[key] = (obj, changes)
|
||||
|
||||
|
||||
@event.listens_for(AsyncSession.sync_session_class, "after_transaction_create")
|
||||
def _after_transaction_create(session: Any, transaction: Any) -> None:
|
||||
if transaction.nested:
|
||||
session.info[_SESSION_SAVEPOINT_DEPTH] = (
|
||||
session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) + 1
|
||||
)
|
||||
|
||||
|
||||
@event.listens_for(AsyncSession.sync_session_class, "after_transaction_end")
|
||||
def _after_transaction_end(session: Any, transaction: Any) -> None:
|
||||
if transaction.nested:
|
||||
depth = session.info.get(_SESSION_SAVEPOINT_DEPTH, 0)
|
||||
if depth > 0: # pragma: no branch
|
||||
session.info[_SESSION_SAVEPOINT_DEPTH] = depth - 1
|
||||
|
||||
|
||||
@event.listens_for(AsyncSession.sync_session_class, "after_flush")
|
||||
def _after_flush(session: Any, flush_context: Any) -> None:
|
||||
# New objects: capture references while session.new is still populated.
|
||||
# Values are read in _after_flush_postexec once RETURNING has been processed.
|
||||
for obj in session.new:
|
||||
if isinstance(obj, WatchedFieldsMixin):
|
||||
session.info.setdefault(_SESSION_PENDING_NEW, []).append(obj)
|
||||
|
||||
# Deleted objects: capture before they leave the identity map.
|
||||
for obj in session.deleted:
|
||||
if isinstance(obj, WatchedFieldsMixin):
|
||||
session.info.setdefault(_SESSION_DELETES, []).append(obj)
|
||||
|
||||
# Dirty objects: read old/new from SQLAlchemy attribute history.
|
||||
for obj in session.dirty:
|
||||
if not isinstance(obj, WatchedFieldsMixin):
|
||||
continue
|
||||
|
||||
# None = not in dict = watch all fields; list = specific fields only
|
||||
watched = _get_watched_fields(type(obj))
|
||||
changes: dict[str, dict[str, Any]] = {}
|
||||
|
||||
attrs = (
|
||||
# Specific fields
|
||||
((field, sa_inspect(obj).attrs[field]) for field in watched)
|
||||
if watched is not None
|
||||
# All mapped fields
|
||||
else ((s.key, s) for s in sa_inspect(obj).attrs)
|
||||
)
|
||||
for field, attr_state in attrs:
|
||||
history = attr_state.history
|
||||
if history.has_changes() and history.deleted:
|
||||
changes[field] = {
|
||||
"old": history.deleted[0],
|
||||
"new": history.added[0] if history.added else None,
|
||||
}
|
||||
|
||||
if changes:
|
||||
_upsert_changes(
|
||||
session.info.setdefault(_SESSION_UPDATES, {}),
|
||||
obj,
|
||||
changes,
|
||||
)
|
||||
|
||||
|
||||
@event.listens_for(AsyncSession.sync_session_class, "after_flush_postexec")
|
||||
def _after_flush_postexec(session: Any, flush_context: Any) -> None:
|
||||
# New objects are now persistent and RETURNING values have been applied,
|
||||
# so server defaults (id, created_at, …) are available via getattr.
|
||||
pending_new: list[Any] = session.info.pop(_SESSION_PENDING_NEW, [])
|
||||
if not pending_new:
|
||||
return
|
||||
session.info.setdefault(_SESSION_CREATES, []).extend(pending_new)
|
||||
|
||||
|
||||
@event.listens_for(AsyncSession.sync_session_class, "after_rollback")
|
||||
def _after_rollback(session: Any) -> None:
|
||||
session.info.pop(_SESSION_PENDING_NEW, None)
|
||||
session.info.pop(_SESSION_CREATES, None)
|
||||
session.info.pop(_SESSION_DELETES, None)
|
||||
session.info.pop(_SESSION_UPDATES, None)
|
||||
|
||||
|
||||
def _task_error_handler(task: asyncio.Task[Any]) -> None:
|
||||
if not task.cancelled() and (exc := task.exception()):
|
||||
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||
|
||||
|
||||
def _schedule_with_snapshot(
|
||||
loop: asyncio.AbstractEventLoop, obj: Any, fn: Any, *args: Any
|
||||
) -> None:
|
||||
"""Snapshot *obj*'s column attrs now (before expire_on_commit wipes them),
|
||||
then schedule a coroutine that restores the snapshot and calls *fn*.
|
||||
"""
|
||||
snapshot = _snapshot_column_attrs(obj)
|
||||
|
||||
async def _run(
|
||||
obj: Any = obj,
|
||||
fn: Any = fn,
|
||||
snapshot: dict[str, Any] = snapshot,
|
||||
args: tuple = args,
|
||||
) -> None:
|
||||
for key, value in snapshot.items():
|
||||
_sa_set_committed_value(obj, key, value)
|
||||
try:
|
||||
result = fn(*args)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
except Exception as exc:
|
||||
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||
|
||||
task = loop.create_task(_run())
|
||||
task.add_done_callback(_task_error_handler)
|
||||
|
||||
|
||||
@event.listens_for(AsyncSession.sync_session_class, "after_commit")
|
||||
def _after_commit(session: Any) -> None:
|
||||
if session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) > 0:
|
||||
return
|
||||
|
||||
creates: list[Any] = session.info.pop(_SESSION_CREATES, [])
|
||||
deletes: list[Any] = session.info.pop(_SESSION_DELETES, [])
|
||||
field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = session.info.pop(
|
||||
_SESSION_UPDATES, {}
|
||||
)
|
||||
|
||||
if creates and deletes:
|
||||
transient_ids = {id(o) for o in creates} & {id(o) for o in deletes}
|
||||
if transient_ids:
|
||||
creates = [o for o in creates if id(o) not in transient_ids]
|
||||
deletes = [o for o in deletes if id(o) not in transient_ids]
|
||||
field_changes = {
|
||||
k: v for k, v in field_changes.items() if k not in transient_ids
|
||||
}
|
||||
|
||||
if not creates and not deletes and not field_changes:
|
||||
return
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return
|
||||
|
||||
for obj in creates:
|
||||
_schedule_with_snapshot(loop, obj, obj.on_create)
|
||||
|
||||
for obj in deletes:
|
||||
_schedule_with_snapshot(loop, obj, obj.on_delete)
|
||||
|
||||
for obj, changes in field_changes.values():
|
||||
_schedule_with_snapshot(loop, obj, obj.on_update, changes)
|
||||
|
||||
|
||||
class WatchedFieldsMixin:
|
||||
"""Mixin that enables lifecycle callbacks for SQLAlchemy models."""
|
||||
|
||||
def on_event(
|
||||
self, event: ModelEvent, changes: dict[str, dict[str, Any]] | None = None
|
||||
) -> Awaitable[None] | None:
|
||||
"""Catch-all callback fired for every lifecycle event.
|
||||
|
||||
Args:
|
||||
event: The event type (:attr:`ModelEvent.CREATE`, :attr:`ModelEvent.DELETE`,
|
||||
or :attr:`ModelEvent.UPDATE`).
|
||||
changes: Field changes for :attr:`ModelEvent.UPDATE`, ``None`` otherwise.
|
||||
"""
|
||||
|
||||
def on_create(self) -> Awaitable[None] | None:
|
||||
"""Called after INSERT commit."""
|
||||
return self.on_event(ModelEvent.CREATE)
|
||||
|
||||
def on_delete(self) -> Awaitable[None] | None:
|
||||
"""Called after DELETE commit."""
|
||||
return self.on_event(ModelEvent.DELETE)
|
||||
|
||||
def on_update(self, changes: dict[str, dict[str, Any]]) -> Awaitable[None] | None:
|
||||
"""Called after UPDATE commit when watched fields change."""
|
||||
return self.on_event(ModelEvent.UPDATE, changes=changes)
|
||||
@@ -1,8 +1,30 @@
|
||||
from .plugin import register_fixtures
|
||||
from .utils import create_async_client, create_db_session
|
||||
"""Pytest helpers for FastAPI testing: sessions, clients, and fixtures."""
|
||||
|
||||
try:
|
||||
from .plugin import register_fixtures
|
||||
except ImportError:
|
||||
from .._imports import require_extra
|
||||
|
||||
require_extra(package="pytest", extra="pytest")
|
||||
|
||||
try:
|
||||
from .utils import (
|
||||
cleanup_tables,
|
||||
create_async_client,
|
||||
create_db_session,
|
||||
create_worker_database,
|
||||
worker_database_url,
|
||||
)
|
||||
except ImportError:
|
||||
from .._imports import require_extra
|
||||
|
||||
require_extra(package="httpx", extra="pytest")
|
||||
|
||||
__all__ = [
|
||||
"cleanup_tables",
|
||||
"create_async_client",
|
||||
"create_db_session",
|
||||
"create_worker_database",
|
||||
"register_fixtures",
|
||||
"worker_database_url",
|
||||
]
|
||||
|
||||
@@ -1,55 +1,4 @@
|
||||
"""Pytest plugin for using FixtureRegistry fixtures in tests.
|
||||
|
||||
This module provides utilities to automatically generate pytest fixtures
|
||||
from your FixtureRegistry, with proper dependency resolution.
|
||||
|
||||
Example:
|
||||
# conftest.py
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
from app.fixtures import fixtures # Your FixtureRegistry
|
||||
from app.models import Base
|
||||
from fastapi_toolsets.pytest_plugin import register_fixtures
|
||||
|
||||
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/test_db"
|
||||
|
||||
@pytest.fixture
|
||||
async def engine():
|
||||
engine = create_async_engine(DATABASE_URL)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(engine):
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
session = session_factory()
|
||||
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
# Automatically generate pytest fixtures from registry
|
||||
# Creates: fixture_roles, fixture_users, fixture_posts, etc.
|
||||
register_fixtures(fixtures, globals())
|
||||
|
||||
Usage in tests:
|
||||
# test_users.py
|
||||
async def test_user_count(db_session, fixture_users):
|
||||
# fixture_users automatically loads fixture_roles first (if dependency)
|
||||
# and returns the list of User models
|
||||
assert len(fixture_users) > 0
|
||||
|
||||
async def test_user_role(db_session, fixture_users):
|
||||
user = fixture_users[0]
|
||||
assert user.role_id is not None
|
||||
"""
|
||||
"""Pytest plugin for using FixtureRegistry fixtures in tests."""
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
@@ -86,6 +35,7 @@ def register_fixtures(
|
||||
List of created fixture names
|
||||
|
||||
Example:
|
||||
```python
|
||||
# conftest.py
|
||||
from app.fixtures import fixtures
|
||||
from fastapi_toolsets.pytest_plugin import register_fixtures
|
||||
@@ -96,6 +46,7 @@ def register_fixtures(
|
||||
# - fixture_roles
|
||||
# - fixture_users (depends on fixture_roles if users depends on roles)
|
||||
# - fixture_posts (depends on fixture_users if posts depends on users)
|
||||
```
|
||||
"""
|
||||
created_fixtures: list[str] = []
|
||||
|
||||
|
||||
@@ -1,26 +1,164 @@
|
||||
"""Pytest helper utilities for FastAPI testing."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from ..db import create_db_context
|
||||
from sqlalchemy import text
|
||||
|
||||
from ..db import (
|
||||
cleanup_tables as _cleanup_tables,
|
||||
create_database,
|
||||
create_db_context,
|
||||
)
|
||||
|
||||
|
||||
async def cleanup_tables(
|
||||
session: AsyncSession,
|
||||
base: type[DeclarativeBase],
|
||||
) -> None:
|
||||
"""Truncate all tables for fast between-test cleanup.
|
||||
|
||||
.. deprecated::
|
||||
Import ``cleanup_tables`` from ``fastapi_toolsets.db`` instead.
|
||||
This re-export will be removed in v3.0.0.
|
||||
"""
|
||||
warnings.warn(
|
||||
"Importing cleanup_tables from fastapi_toolsets.pytest is deprecated "
|
||||
"and will be removed in v3.0.0. "
|
||||
"Use 'from fastapi_toolsets.db import cleanup_tables' instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
await _cleanup_tables(session=session, base=base)
|
||||
|
||||
|
||||
def _get_xdist_worker(default_test_db: str) -> str:
|
||||
"""Return the pytest-xdist worker name, or *default_test_db* when not running under xdist.
|
||||
|
||||
Reads the ``PYTEST_XDIST_WORKER`` environment variable that xdist sets
|
||||
automatically in each worker process (e.g. ``"gw0"``, ``"gw1"``).
|
||||
When xdist is not installed or not active, the variable is absent and
|
||||
*default_test_db* is returned instead.
|
||||
|
||||
Args:
|
||||
default_test_db: Fallback value returned when ``PYTEST_XDIST_WORKER``
|
||||
is not set.
|
||||
"""
|
||||
return os.environ.get("PYTEST_XDIST_WORKER", default_test_db)
|
||||
|
||||
|
||||
def worker_database_url(database_url: str, default_test_db: str) -> str:
|
||||
"""Derive a per-worker database URL for pytest-xdist parallel runs.
|
||||
|
||||
Appends ``_{worker_name}`` to the database name so each xdist worker
|
||||
operates on its own database. When not running under xdist,
|
||||
``_{default_test_db}`` is appended instead.
|
||||
|
||||
The worker name is read from the ``PYTEST_XDIST_WORKER`` environment
|
||||
variable (set automatically by xdist in each worker process).
|
||||
|
||||
Args:
|
||||
database_url: Original database connection URL.
|
||||
default_test_db: Suffix appended to the database name when
|
||||
``PYTEST_XDIST_WORKER`` is not set.
|
||||
|
||||
Returns:
|
||||
A database URL with a worker- or default-specific database name.
|
||||
"""
|
||||
worker = _get_xdist_worker(default_test_db=default_test_db)
|
||||
|
||||
url = make_url(database_url)
|
||||
url = url.set(database=f"{url.database}_{worker}")
|
||||
return url.render_as_string(hide_password=False)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def create_worker_database(
|
||||
database_url: str,
|
||||
default_test_db: str = "test_db",
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Create and drop a per-worker database for pytest-xdist isolation.
|
||||
|
||||
Derives a worker-specific database URL using :func:`worker_database_url`,
|
||||
then delegates to :func:`~fastapi_toolsets.db.create_database` to create
|
||||
and drop it. Intended for use as a **session-scoped** fixture.
|
||||
|
||||
When running under xdist the database name is suffixed with the worker
|
||||
name (e.g. ``_gw0``). Otherwise it is suffixed with *default_test_db*.
|
||||
|
||||
Args:
|
||||
database_url: Original database connection URL (used as the server
|
||||
connection and as the base for the worker database name).
|
||||
default_test_db: Suffix appended to the database name when
|
||||
``PYTEST_XDIST_WORKER`` is not set. Defaults to ``"test_db"``.
|
||||
|
||||
Yields:
|
||||
The worker-specific database URL.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from fastapi_toolsets.pytest import create_worker_database, create_db_session
|
||||
|
||||
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost/test_db"
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def worker_db_url():
|
||||
async with create_worker_database(DATABASE_URL) as url:
|
||||
yield url
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(worker_db_url):
|
||||
async with create_db_session(
|
||||
worker_db_url, Base, cleanup=True
|
||||
) as session:
|
||||
yield session
|
||||
```
|
||||
"""
|
||||
worker_url = worker_database_url(
|
||||
database_url=database_url, default_test_db=default_test_db
|
||||
)
|
||||
worker_db_name: str = make_url(worker_url).database # type: ignore[assignment]
|
||||
|
||||
engine = create_async_engine(database_url, isolation_level="AUTOCOMMIT")
|
||||
try:
|
||||
async with engine.connect() as conn:
|
||||
await conn.execute(text(f"DROP DATABASE IF EXISTS {worker_db_name}"))
|
||||
await create_database(db_name=worker_db_name, server_url=database_url)
|
||||
|
||||
yield worker_url
|
||||
|
||||
async with engine.connect() as conn:
|
||||
await conn.execute(text(f"DROP DATABASE IF EXISTS {worker_db_name}"))
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def create_async_client(
|
||||
app: Any,
|
||||
base_url: str = "http://test",
|
||||
dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None,
|
||||
) -> AsyncGenerator[AsyncClient, None]:
|
||||
"""Create an async httpx client for testing FastAPI applications.
|
||||
|
||||
Args:
|
||||
app: FastAPI application instance.
|
||||
base_url: Base URL for requests. Defaults to "http://test".
|
||||
dependency_overrides: Optional mapping of original dependencies to
|
||||
their test replacements. Applied via ``app.dependency_overrides``
|
||||
before yielding and cleaned up after.
|
||||
|
||||
Yields:
|
||||
An AsyncClient configured for the app.
|
||||
@@ -41,10 +179,39 @@ async def create_async_client(
|
||||
response = await client.get("/health")
|
||||
assert response.status_code == 200
|
||||
```
|
||||
|
||||
Example with dependency overrides:
|
||||
```python
|
||||
from fastapi_toolsets.pytest import create_async_client, create_db_session
|
||||
from app.db import get_db
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session():
|
||||
async with create_db_session(DATABASE_URL, Base, cleanup=True) as session:
|
||||
yield session
|
||||
|
||||
@pytest.fixture
|
||||
async def client(db_session):
|
||||
async def override():
|
||||
yield db_session
|
||||
|
||||
async with create_async_client(
|
||||
app, dependency_overrides={get_db: override}
|
||||
) as c:
|
||||
yield c
|
||||
```
|
||||
"""
|
||||
if dependency_overrides:
|
||||
app.dependency_overrides.update(dependency_overrides)
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url=base_url) as client:
|
||||
yield client
|
||||
try:
|
||||
async with AsyncClient(transport=transport, base_url=base_url) as client:
|
||||
yield client
|
||||
finally:
|
||||
if dependency_overrides:
|
||||
for key in dependency_overrides:
|
||||
app.dependency_overrides.pop(key, None)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -55,6 +222,7 @@ async def create_db_session(
|
||||
echo: bool = False,
|
||||
expire_on_commit: bool = False,
|
||||
drop_tables: bool = True,
|
||||
cleanup: bool = False,
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create a database session for testing.
|
||||
|
||||
@@ -67,6 +235,8 @@ async def create_db_session(
|
||||
echo: Enable SQLAlchemy query logging. Defaults to False.
|
||||
expire_on_commit: Expire objects after commit. Defaults to False.
|
||||
drop_tables: Drop tables after test. Defaults to True.
|
||||
cleanup: Truncate all tables after test using
|
||||
:func:`cleanup_tables`. Defaults to False.
|
||||
|
||||
Yields:
|
||||
An AsyncSession ready for database operations.
|
||||
@@ -80,7 +250,9 @@ async def create_db_session(
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session():
|
||||
async with create_db_session(DATABASE_URL, Base) as session:
|
||||
async with create_db_session(
|
||||
DATABASE_URL, Base, cleanup=True
|
||||
) as session:
|
||||
yield session
|
||||
|
||||
async def test_create_user(db_session: AsyncSession):
|
||||
@@ -103,6 +275,9 @@ async def create_db_session(
|
||||
async with get_session() as session:
|
||||
yield session
|
||||
|
||||
if cleanup:
|
||||
await cleanup_tables(session, base)
|
||||
|
||||
if drop_tables:
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(base.metadata.drop_all)
|
||||
|
||||
@@ -1,21 +1,27 @@
|
||||
"""Base Pydantic schemas for API responses."""
|
||||
|
||||
import math
|
||||
from enum import Enum
|
||||
from typing import ClassVar, Generic, TypeVar
|
||||
from typing import Annotated, Any, ClassVar, Generic, Literal, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field, computed_field
|
||||
|
||||
from .types import DataT
|
||||
|
||||
__all__ = [
|
||||
"ApiError",
|
||||
"CursorPagination",
|
||||
"CursorPaginatedResponse",
|
||||
"ErrorResponse",
|
||||
"Pagination",
|
||||
"OffsetPagination",
|
||||
"OffsetPaginatedResponse",
|
||||
"PaginatedResponse",
|
||||
"PaginationType",
|
||||
"PydanticBase",
|
||||
"Response",
|
||||
"ResponseStatus",
|
||||
]
|
||||
|
||||
DataT = TypeVar("DataT")
|
||||
|
||||
|
||||
class PydanticBase(BaseModel):
|
||||
"""Base class for all Pydantic models with common configuration."""
|
||||
@@ -49,6 +55,7 @@ class ApiError(PydanticBase):
|
||||
msg: str
|
||||
desc: str
|
||||
err_code: str
|
||||
data: Any | None = None
|
||||
|
||||
|
||||
class BaseResponse(PydanticBase):
|
||||
@@ -69,7 +76,9 @@ class Response(BaseResponse, Generic[DataT]):
|
||||
"""Generic API response with data payload.
|
||||
|
||||
Example:
|
||||
```python
|
||||
Response[UserRead](data=user, message="User retrieved")
|
||||
```
|
||||
"""
|
||||
|
||||
data: DataT | None = None
|
||||
@@ -83,34 +92,113 @@ class ErrorResponse(BaseResponse):
|
||||
|
||||
status: ResponseStatus = ResponseStatus.FAIL
|
||||
description: str | None = None
|
||||
data: None = None
|
||||
data: Any | None = None
|
||||
|
||||
|
||||
class Pagination(PydanticBase):
|
||||
"""Pagination metadata for list responses.
|
||||
class OffsetPagination(PydanticBase):
|
||||
"""Pagination metadata for offset-based list responses.
|
||||
|
||||
Attributes:
|
||||
total_count: Total number of items across all pages
|
||||
total_count: Total number of items across all pages.
|
||||
``None`` when ``include_total=False``.
|
||||
items_per_page: Number of items per page
|
||||
page: Current page number (1-indexed)
|
||||
has_more: Whether there are more pages
|
||||
pages: Total number of pages
|
||||
"""
|
||||
|
||||
total_count: int
|
||||
total_count: int | None
|
||||
items_per_page: int
|
||||
page: int
|
||||
has_more: bool
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def pages(self) -> int | None:
|
||||
"""Total number of pages, or ``None`` when ``total_count`` is unknown."""
|
||||
if self.total_count is None:
|
||||
return None
|
||||
if self.items_per_page == 0:
|
||||
return 0
|
||||
return math.ceil(self.total_count / self.items_per_page)
|
||||
|
||||
|
||||
class CursorPagination(PydanticBase):
|
||||
"""Pagination metadata for cursor-based list responses.
|
||||
|
||||
Attributes:
|
||||
next_cursor: Encoded cursor for the next page, or None on the last page.
|
||||
prev_cursor: Encoded cursor for the previous page, or None on the first page.
|
||||
items_per_page: Number of items requested per page.
|
||||
has_more: Whether there is at least one more page after this one.
|
||||
"""
|
||||
|
||||
next_cursor: str | None
|
||||
prev_cursor: str | None = None
|
||||
items_per_page: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class PaginationType(str, Enum):
|
||||
"""Pagination strategy selector for :meth:`.AsyncCrud.paginate`."""
|
||||
|
||||
OFFSET = "offset"
|
||||
CURSOR = "cursor"
|
||||
|
||||
|
||||
class PaginatedResponse(BaseResponse, Generic[DataT]):
|
||||
"""Paginated API response for list endpoints.
|
||||
|
||||
Example:
|
||||
PaginatedResponse[UserRead](
|
||||
data=users,
|
||||
pagination=Pagination(total_count=100, items_per_page=10, page=1, has_more=True)
|
||||
)
|
||||
Base class and return type for endpoints that support both pagination
|
||||
strategies. Use :class:`OffsetPaginatedResponse` or
|
||||
:class:`CursorPaginatedResponse` when the strategy is fixed.
|
||||
|
||||
When used as ``PaginatedResponse[T]`` in a return annotation, subscripting
|
||||
returns ``Annotated[Union[CursorPaginatedResponse[T], OffsetPaginatedResponse[T]], Field(discriminator="pagination_type")]``
|
||||
so FastAPI emits a proper ``oneOf`` + discriminator in the OpenAPI schema.
|
||||
"""
|
||||
|
||||
data: list[DataT]
|
||||
pagination: Pagination
|
||||
pagination: OffsetPagination | CursorPagination
|
||||
pagination_type: PaginationType | None = None
|
||||
filter_attributes: dict[str, list[Any]] | None = None
|
||||
|
||||
_discriminated_union_cache: ClassVar[dict[Any, Any]] = {}
|
||||
|
||||
def __class_getitem__( # type: ignore[invalid-method-override]
|
||||
cls, item: type[Any] | tuple[type[Any], ...]
|
||||
) -> type[Any]:
|
||||
if cls is PaginatedResponse and not isinstance(item, TypeVar):
|
||||
cached = cls._discriminated_union_cache.get(item)
|
||||
if cached is None:
|
||||
cached = Annotated[
|
||||
Union[CursorPaginatedResponse[item], OffsetPaginatedResponse[item]], # type: ignore[invalid-type-form]
|
||||
Field(discriminator="pagination_type"),
|
||||
]
|
||||
cls._discriminated_union_cache[item] = cached
|
||||
return cached # type: ignore[invalid-return-type]
|
||||
return super().__class_getitem__(item)
|
||||
|
||||
|
||||
class OffsetPaginatedResponse(PaginatedResponse[DataT]):
|
||||
"""Paginated response with typed offset-based pagination metadata.
|
||||
|
||||
The ``pagination_type`` field is always ``"offset"`` and acts as a
|
||||
discriminator, allowing frontend clients to narrow the union type returned
|
||||
by a unified ``paginate()`` endpoint.
|
||||
"""
|
||||
|
||||
pagination: OffsetPagination
|
||||
pagination_type: Literal[PaginationType.OFFSET] = PaginationType.OFFSET
|
||||
|
||||
|
||||
class CursorPaginatedResponse(PaginatedResponse[DataT]):
|
||||
"""Paginated response with typed cursor-based pagination metadata.
|
||||
|
||||
The ``pagination_type`` field is always ``"cursor"`` and acts as a
|
||||
discriminator, allowing frontend clients to narrow the union type returned
|
||||
by a unified ``paginate()`` endpoint.
|
||||
"""
|
||||
|
||||
pagination: CursorPagination
|
||||
pagination_type: Literal[PaginationType.CURSOR] = PaginationType.CURSOR
|
||||
|
||||
27
src/fastapi_toolsets/types.py
Normal file
27
src/fastapi_toolsets/types.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Shared type aliases for the fastapi-toolsets package."""
|
||||
|
||||
from collections.abc import AsyncGenerator, Callable, Mapping
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
|
||||
# Generic TypeVars
|
||||
DataT = TypeVar("DataT")
|
||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||
SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
||||
|
||||
# CRUD type aliases
|
||||
JoinType = list[tuple[type[DeclarativeBase], Any]]
|
||||
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
|
||||
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]
|
||||
|
||||
# Search / facet type aliases
|
||||
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
|
||||
FacetFieldType = SearchFieldType
|
||||
|
||||
# Dependency type aliases
|
||||
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]] | Any
|
||||
@@ -1,27 +1,36 @@
|
||||
"""Shared pytest fixtures for fastapi-utils tests."""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import ForeignKey, String
|
||||
import datetime
|
||||
import decimal
|
||||
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Date,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
Table,
|
||||
Uuid,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
|
||||
from fastapi_toolsets.crud import CrudFactory
|
||||
from fastapi_toolsets.schemas import PydanticBase
|
||||
|
||||
# PostgreSQL connection URL from environment or default for local development
|
||||
DATABASE_URL = os.getenv("DATABASE_URL") or os.getenv(
|
||||
"TEST_DATABASE_URL",
|
||||
"postgresql+asyncpg://postgres:postgres@localhost:5432/fastapi_toolsets_test",
|
||||
DATABASE_URL = os.getenv(
|
||||
key="DATABASE_URL",
|
||||
default="postgresql+asyncpg://postgres:postgres@localhost:5432/postgres",
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for test models."""
|
||||
|
||||
@@ -33,7 +42,7 @@ class Role(Base):
|
||||
|
||||
__tablename__ = "roles"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
|
||||
users: Mapped[list["User"]] = relationship(back_populates="role")
|
||||
@@ -44,36 +53,100 @@ class User(Base):
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
username: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
email: Mapped[str] = mapped_column(String(100), unique=True)
|
||||
is_active: Mapped[bool] = mapped_column(default=True)
|
||||
role_id: Mapped[int | None] = mapped_column(ForeignKey("roles.id"), nullable=True)
|
||||
role_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
ForeignKey("roles.id"), nullable=True
|
||||
)
|
||||
|
||||
role: Mapped[Role | None] = relationship(back_populates="users")
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
"""Test tag model."""
|
||||
|
||||
__tablename__ = "tags"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
|
||||
|
||||
post_tags = Table(
|
||||
"post_tags",
|
||||
Base.metadata,
|
||||
Column(
|
||||
"post_id", Uuid, ForeignKey("posts.id", ondelete="CASCADE"), primary_key=True
|
||||
),
|
||||
Column("tag_id", Uuid, ForeignKey("tags.id", ondelete="CASCADE"), primary_key=True),
|
||||
)
|
||||
|
||||
|
||||
class IntRole(Base):
|
||||
"""Test role model with auto-increment integer PK."""
|
||||
|
||||
__tablename__ = "int_roles"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
|
||||
|
||||
class Permission(Base):
|
||||
"""Test model with composite primary key."""
|
||||
|
||||
__tablename__ = "permissions"
|
||||
|
||||
subject: Mapped[str] = mapped_column(String(50), primary_key=True)
|
||||
action: Mapped[str] = mapped_column(String(50), primary_key=True)
|
||||
|
||||
|
||||
class Event(Base):
|
||||
"""Test model with DateTime and Date cursor columns."""
|
||||
|
||||
__tablename__ = "events"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
name: Mapped[str] = mapped_column(String(100))
|
||||
occurred_at: Mapped[datetime.datetime] = mapped_column(DateTime)
|
||||
scheduled_date: Mapped[datetime.date] = mapped_column(Date)
|
||||
|
||||
|
||||
class Product(Base):
|
||||
"""Test model with Numeric cursor column."""
|
||||
|
||||
__tablename__ = "products"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
name: Mapped[str] = mapped_column(String(100))
|
||||
price: Mapped[decimal.Decimal] = mapped_column(Numeric(10, 2))
|
||||
|
||||
|
||||
class Post(Base):
|
||||
"""Test post model."""
|
||||
|
||||
__tablename__ = "posts"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
title: Mapped[str] = mapped_column(String(200))
|
||||
content: Mapped[str] = mapped_column(String(1000), default="")
|
||||
is_published: Mapped[bool] = mapped_column(default=False)
|
||||
author_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
|
||||
author_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Schemas
|
||||
# =============================================================================
|
||||
tags: Mapped[list[Tag]] = relationship(secondary=post_tags)
|
||||
|
||||
|
||||
class RoleCreate(BaseModel):
|
||||
"""Schema for creating a role."""
|
||||
|
||||
id: int | None = None
|
||||
id: uuid.UUID | None = None
|
||||
name: str
|
||||
|
||||
|
||||
class RoleRead(PydanticBase):
|
||||
"""Schema for reading a role."""
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
|
||||
|
||||
@@ -86,11 +159,19 @@ class RoleUpdate(BaseModel):
|
||||
class UserCreate(BaseModel):
|
||||
"""Schema for creating a user."""
|
||||
|
||||
id: int | None = None
|
||||
id: uuid.UUID | None = None
|
||||
username: str
|
||||
email: str
|
||||
is_active: bool = True
|
||||
role_id: int | None = None
|
||||
role_id: uuid.UUID | None = None
|
||||
|
||||
|
||||
class UserRead(PydanticBase):
|
||||
"""Schema for reading a user (subset of fields — no email)."""
|
||||
|
||||
id: uuid.UUID
|
||||
username: str
|
||||
is_active: bool = True
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
@@ -99,17 +180,24 @@ class UserUpdate(BaseModel):
|
||||
username: str | None = None
|
||||
email: str | None = None
|
||||
is_active: bool | None = None
|
||||
role_id: int | None = None
|
||||
role_id: uuid.UUID | None = None
|
||||
|
||||
|
||||
class TagCreate(BaseModel):
|
||||
"""Schema for creating a tag."""
|
||||
|
||||
id: uuid.UUID | None = None
|
||||
name: str
|
||||
|
||||
|
||||
class PostCreate(BaseModel):
|
||||
"""Schema for creating a post."""
|
||||
|
||||
id: int | None = None
|
||||
id: uuid.UUID | None = None
|
||||
title: str
|
||||
content: str = ""
|
||||
is_published: bool = False
|
||||
author_id: int
|
||||
author_id: uuid.UUID
|
||||
|
||||
|
||||
class PostUpdate(BaseModel):
|
||||
@@ -120,18 +208,81 @@ class PostUpdate(BaseModel):
|
||||
is_published: bool | None = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CRUD Classes
|
||||
# =============================================================================
|
||||
class PostM2MCreate(BaseModel):
|
||||
"""Schema for creating a post with M2M tag IDs."""
|
||||
|
||||
id: uuid.UUID | None = None
|
||||
title: str
|
||||
content: str = ""
|
||||
is_published: bool = False
|
||||
author_id: uuid.UUID
|
||||
tag_ids: list[uuid.UUID] = []
|
||||
|
||||
|
||||
class PostM2MUpdate(BaseModel):
|
||||
"""Schema for updating a post with M2M tag IDs."""
|
||||
|
||||
title: str | None = None
|
||||
content: str | None = None
|
||||
is_published: bool | None = None
|
||||
tag_ids: list[uuid.UUID] | None = None
|
||||
|
||||
|
||||
class IntRoleRead(PydanticBase):
|
||||
"""Schema for reading an IntRole."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
|
||||
|
||||
class IntRoleCreate(BaseModel):
|
||||
"""Schema for creating an IntRole."""
|
||||
|
||||
name: str
|
||||
|
||||
|
||||
class EventRead(PydanticBase):
|
||||
"""Schema for reading an Event."""
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
|
||||
|
||||
class EventCreate(BaseModel):
|
||||
"""Schema for creating an Event."""
|
||||
|
||||
name: str
|
||||
occurred_at: datetime.datetime
|
||||
scheduled_date: datetime.date
|
||||
|
||||
|
||||
class ProductRead(PydanticBase):
|
||||
"""Schema for reading a Product."""
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
|
||||
|
||||
class ProductCreate(BaseModel):
|
||||
"""Schema for creating a Product."""
|
||||
|
||||
name: str
|
||||
price: decimal.Decimal
|
||||
|
||||
|
||||
RoleCrud = CrudFactory(Role)
|
||||
RoleCursorCrud = CrudFactory(Role, cursor_column=Role.id)
|
||||
IntRoleCursorCrud = CrudFactory(IntRole, cursor_column=IntRole.id)
|
||||
UserCrud = CrudFactory(User)
|
||||
UserCursorCrud = CrudFactory(User, cursor_column=User.id)
|
||||
PostCrud = CrudFactory(Post)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
TagCrud = CrudFactory(Tag)
|
||||
PostM2MCrud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
|
||||
EventCrud = CrudFactory(Event)
|
||||
EventDateTimeCursorCrud = CrudFactory(Event, cursor_column=Event.occurred_at)
|
||||
EventDateCursorCrud = CrudFactory(Event, cursor_column=Event.scheduled_date)
|
||||
ProductCrud = CrudFactory(Product)
|
||||
ProductNumericCursorCrud = CrudFactory(Product, cursor_column=Product.price)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -170,30 +321,3 @@ async def db_session(engine):
|
||||
# Drop tables after test
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_role_data() -> RoleCreate:
|
||||
"""Sample role creation data."""
|
||||
return RoleCreate(name="admin")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_user_data() -> UserCreate:
|
||||
"""Sample user creation data."""
|
||||
return UserCreate(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_post_data() -> PostCreate:
|
||||
"""Sample post creation data."""
|
||||
return PostCreate(
|
||||
title="Test Post",
|
||||
content="Test content",
|
||||
is_published=True,
|
||||
author_id=1,
|
||||
)
|
||||
|
||||
543
tests/test_cli.py
Normal file
543
tests/test_cli.py
Normal file
@@ -0,0 +1,543 @@
|
||||
"""Tests for fastapi_toolsets.cli module."""
|
||||
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from fastapi_toolsets.cli.config import (
|
||||
get_config_value,
|
||||
get_custom_cli,
|
||||
get_db_context,
|
||||
get_fixtures_registry,
|
||||
import_from_string,
|
||||
)
|
||||
from fastapi_toolsets.cli.pyproject import find_pyproject, load_pyproject
|
||||
from fastapi_toolsets.cli.utils import async_command
|
||||
from fastapi_toolsets.fixtures import FixtureRegistry
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
class TestPyproject:
|
||||
"""Tests for pyproject.toml discovery and loading."""
|
||||
|
||||
def test_find_pyproject_in_current_dir(self, tmp_path, monkeypatch):
|
||||
"""Finds pyproject.toml in current directory."""
|
||||
pyproject = tmp_path / "pyproject.toml"
|
||||
pyproject.write_text("[project]\nname = 'test'\n")
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
result = find_pyproject()
|
||||
assert result == pyproject
|
||||
|
||||
def test_find_pyproject_in_parent_dir(self, tmp_path, monkeypatch):
|
||||
"""Finds pyproject.toml in parent directory."""
|
||||
pyproject = tmp_path / "pyproject.toml"
|
||||
pyproject.write_text("[project]\nname = 'test'\n")
|
||||
subdir = tmp_path / "src" / "app"
|
||||
subdir.mkdir(parents=True)
|
||||
monkeypatch.chdir(subdir)
|
||||
|
||||
result = find_pyproject()
|
||||
assert result == pyproject
|
||||
|
||||
def test_find_pyproject_not_found(self, tmp_path, monkeypatch):
|
||||
"""Returns None when no pyproject.toml exists."""
|
||||
monkeypatch.chdir(tmp_path)
|
||||
result = find_pyproject()
|
||||
assert result is None
|
||||
|
||||
def test_load_pyproject_returns_tool_config(self, tmp_path, monkeypatch):
|
||||
"""load_pyproject returns the [tool.fastapi-toolsets] section."""
|
||||
pyproject = tmp_path / "pyproject.toml"
|
||||
pyproject.write_text(
|
||||
'[tool.fastapi-toolsets]\nfixtures = "app.fixtures:registry"\n'
|
||||
)
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
result = load_pyproject()
|
||||
assert result == {"fixtures": "app.fixtures:registry"}
|
||||
|
||||
def test_load_pyproject_empty_when_no_file(self, tmp_path, monkeypatch):
|
||||
"""Returns empty dict when no pyproject.toml exists."""
|
||||
monkeypatch.chdir(tmp_path)
|
||||
result = load_pyproject()
|
||||
assert result == {}
|
||||
|
||||
def test_load_pyproject_empty_when_no_tool_section(self, tmp_path, monkeypatch):
|
||||
"""Returns empty dict when no [tool.fastapi-toolsets] section."""
|
||||
pyproject = tmp_path / "pyproject.toml"
|
||||
pyproject.write_text("[project]\nname = 'test'\n")
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
result = load_pyproject()
|
||||
assert result == {}
|
||||
|
||||
def test_load_pyproject_invalid_toml(self, tmp_path, monkeypatch):
|
||||
"""Returns empty dict when pyproject.toml is invalid."""
|
||||
pyproject = tmp_path / "pyproject.toml"
|
||||
pyproject.write_text("invalid toml {{{")
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
result = load_pyproject()
|
||||
assert result == {}
|
||||
|
||||
|
||||
class TestImportFromString:
|
||||
"""Tests for import_from_string function."""
|
||||
|
||||
def test_import_valid_path(self):
|
||||
"""Import valid module:attribute path."""
|
||||
result = import_from_string("fastapi_toolsets.fixtures:FixtureRegistry")
|
||||
assert result is FixtureRegistry
|
||||
|
||||
def test_import_without_colon_raises_error(self):
|
||||
"""Import path without colon raises error."""
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
import_from_string("fastapi_toolsets.fixtures.FixtureRegistry")
|
||||
assert "Expected format: 'module:attribute'" in str(exc_info.value)
|
||||
|
||||
def test_import_nonexistent_module_raises_error(self):
|
||||
"""Import nonexistent module raises error."""
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
import_from_string("nonexistent.module:something")
|
||||
assert "Cannot import module" in str(exc_info.value)
|
||||
|
||||
def test_import_nonexistent_attribute_raises_error(self):
|
||||
"""Import nonexistent attribute raises error."""
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
import_from_string("fastapi_toolsets.fixtures:NonexistentClass")
|
||||
assert "has no attribute" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestGetConfigValue:
|
||||
"""Tests for get_config_value function."""
|
||||
|
||||
def test_get_existing_value(self, tmp_path, monkeypatch):
|
||||
"""Returns value when key exists."""
|
||||
pyproject = tmp_path / "pyproject.toml"
|
||||
pyproject.write_text('[tool.fastapi-toolsets]\nfixtures = "app:registry"\n')
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
result = get_config_value("fixtures")
|
||||
assert result == "app:registry"
|
||||
|
||||
def test_get_missing_value_returns_none(self, tmp_path, monkeypatch):
|
||||
"""Returns None when key is missing and not required."""
|
||||
pyproject = tmp_path / "pyproject.toml"
|
||||
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
result = get_config_value("fixtures")
|
||||
assert result is None
|
||||
|
||||
def test_get_missing_value_required_raises_error(self, tmp_path, monkeypatch):
|
||||
"""Raises error when key is missing and required."""
|
||||
pyproject = tmp_path / "pyproject.toml"
|
||||
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
get_config_value("fixtures", required=True)
|
||||
assert "No 'fixtures' configured" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestGetFixturesRegistry:
|
||||
"""Tests for get_fixtures_registry function."""
|
||||
|
||||
def test_raises_when_not_configured(self, tmp_path, monkeypatch):
|
||||
"""Raises error when fixtures not configured."""
|
||||
pyproject = tmp_path / "pyproject.toml"
|
||||
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
get_fixtures_registry()
|
||||
assert "No 'fixtures' configured" in str(exc_info.value)
|
||||
|
||||
def test_raises_when_not_registry_instance(self, tmp_path, monkeypatch):
|
||||
"""Raises error when imported object is not a FixtureRegistry."""
|
||||
pyproject = tmp_path / "pyproject.toml"
|
||||
pyproject.write_text(
|
||||
'[tool.fastapi-toolsets]\nfixtures = "my_fixtures:registry"\n'
|
||||
)
|
||||
|
||||
fixtures_file = tmp_path / "my_fixtures.py"
|
||||
fixtures_file.write_text("registry = 'not a registry'\n")
|
||||
|
||||
monkeypatch.chdir(tmp_path)
|
||||
if str(tmp_path) not in sys.path:
|
||||
sys.path.insert(0, str(tmp_path))
|
||||
|
||||
try:
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
get_fixtures_registry()
|
||||
assert "must be a FixtureRegistry instance" in str(exc_info.value)
|
||||
finally:
|
||||
if str(tmp_path) in sys.path:
|
||||
sys.path.remove(str(tmp_path))
|
||||
if "my_fixtures" in sys.modules:
|
||||
del sys.modules["my_fixtures"]
|
||||
|
||||
|
||||
class TestGetDbContext:
|
||||
"""Tests for get_db_context function."""
|
||||
|
||||
def test_raises_when_not_configured(self, tmp_path, monkeypatch):
|
||||
"""Raises error when db_context not configured."""
|
||||
pyproject = tmp_path / "pyproject.toml"
|
||||
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
get_db_context()
|
||||
assert "No 'db_context' configured" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestGetCustomCli:
|
||||
"""Tests for get_custom_cli function."""
|
||||
|
||||
def test_returns_none_when_not_configured(self, tmp_path, monkeypatch):
|
||||
"""Returns None when custom_cli not configured."""
|
||||
pyproject = tmp_path / "pyproject.toml"
|
||||
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
result = get_custom_cli()
|
||||
assert result is None
|
||||
|
||||
def test_raises_when_not_typer_instance(self, tmp_path, monkeypatch):
|
||||
"""Raises error when imported object is not a Typer instance."""
|
||||
pyproject = tmp_path / "pyproject.toml"
|
||||
pyproject.write_text('[tool.fastapi-toolsets]\ncustom_cli = "my_cli:cli"\n')
|
||||
|
||||
cli_file = tmp_path / "my_cli.py"
|
||||
cli_file.write_text("cli = 'not a typer'\n")
|
||||
|
||||
monkeypatch.chdir(tmp_path)
|
||||
if str(tmp_path) not in sys.path:
|
||||
sys.path.insert(0, str(tmp_path))
|
||||
|
||||
try:
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
get_custom_cli()
|
||||
assert "must be a Typer instance" in str(exc_info.value)
|
||||
finally:
|
||||
if str(tmp_path) in sys.path:
|
||||
sys.path.remove(str(tmp_path))
|
||||
if "my_cli" in sys.modules:
|
||||
del sys.modules["my_cli"]
|
||||
|
||||
|
||||
class TestCliApp:
|
||||
"""Tests for CLI application."""
|
||||
|
||||
def test_cli_help(self, tmp_path, monkeypatch):
|
||||
"""CLI shows help without fixtures."""
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
# Need to reload the module to pick up new cwd
|
||||
import importlib
|
||||
|
||||
from fastapi_toolsets.cli import app
|
||||
|
||||
importlib.reload(app)
|
||||
|
||||
result = runner.invoke(app.cli, ["--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "CLI utilities for FastAPI projects" in result.output
|
||||
|
||||
|
||||
class TestFixturesCli:
|
||||
"""Tests for fixtures CLI commands."""
|
||||
|
||||
@pytest.fixture
|
||||
def cli_env(self, tmp_path, monkeypatch):
|
||||
"""Set up CLI environment with fixtures config."""
|
||||
# Create pyproject.toml
|
||||
pyproject = tmp_path / "pyproject.toml"
|
||||
pyproject.write_text(
|
||||
"[tool.fastapi-toolsets]\n"
|
||||
'fixtures = "fixtures:registry"\n'
|
||||
'db_context = "db:get_session"\n'
|
||||
)
|
||||
|
||||
# Create fixtures module
|
||||
fixtures_file = tmp_path / "fixtures.py"
|
||||
fixtures_file.write_text(
|
||||
"from fastapi_toolsets.fixtures import FixtureRegistry, Context\n"
|
||||
"\n"
|
||||
"registry = FixtureRegistry()\n"
|
||||
"\n"
|
||||
"@registry.register(contexts=[Context.BASE])\n"
|
||||
"def roles():\n"
|
||||
' return [{"id": 1, "name": "admin"}, {"id": 2, "name": "user"}]\n'
|
||||
"\n"
|
||||
'@registry.register(depends_on=["roles"], contexts=[Context.TESTING])\n'
|
||||
"def users():\n"
|
||||
' return [{"id": 1, "name": "alice", "role_id": 1}]\n'
|
||||
)
|
||||
|
||||
# Create db module
|
||||
db_file = tmp_path / "db.py"
|
||||
db_file.write_text(
|
||||
"from contextlib import asynccontextmanager\n"
|
||||
"\n"
|
||||
"@asynccontextmanager\n"
|
||||
"async def get_session():\n"
|
||||
" yield None\n"
|
||||
)
|
||||
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
# Add tmp_path to sys.path for imports
|
||||
if str(tmp_path) not in sys.path:
|
||||
sys.path.insert(0, str(tmp_path))
|
||||
|
||||
# Reload the CLI module to pick up new config
|
||||
import importlib
|
||||
|
||||
from fastapi_toolsets.cli import app
|
||||
|
||||
importlib.reload(app)
|
||||
|
||||
yield tmp_path, app.cli
|
||||
|
||||
# Cleanup
|
||||
if str(tmp_path) in sys.path:
|
||||
sys.path.remove(str(tmp_path))
|
||||
|
||||
def test_fixtures_list(self, cli_env):
|
||||
"""fixtures list shows registered fixtures."""
|
||||
tmp_path, cli = cli_env
|
||||
result = runner.invoke(cli, ["fixtures", "list"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "roles" in result.output
|
||||
assert "users" in result.output
|
||||
assert "Total: 2 fixture(s)" in result.output
|
||||
|
||||
def test_fixtures_list_with_context(self, cli_env):
|
||||
"""fixtures list --context filters by context."""
|
||||
tmp_path, cli = cli_env
|
||||
result = runner.invoke(cli, ["fixtures", "list", "--context", "base"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "roles" in result.output
|
||||
assert "users" not in result.output
|
||||
assert "Total: 1 fixture(s)" in result.output
|
||||
|
||||
def test_fixtures_load_dry_run(self, cli_env):
|
||||
"""fixtures load --dry-run shows what would be loaded."""
|
||||
tmp_path, cli = cli_env
|
||||
result = runner.invoke(cli, ["fixtures", "load", "base", "--dry-run"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Fixtures to load" in result.output
|
||||
assert "roles" in result.output
|
||||
assert "[Dry run - no changes made]" in result.output
|
||||
|
||||
def test_fixtures_load_invalid_strategy(self, cli_env):
|
||||
"""fixtures load with invalid strategy shows error."""
|
||||
tmp_path, cli = cli_env
|
||||
result = runner.invoke(
|
||||
cli, ["fixtures", "load", "base", "--strategy", "invalid"]
|
||||
)
|
||||
|
||||
assert result.exit_code != 0
|
||||
|
||||
|
||||
class TestCliWithoutFixturesConfig:
|
||||
"""Tests for CLI when fixtures is not configured."""
|
||||
|
||||
def test_no_fixtures_command(self, tmp_path, monkeypatch):
|
||||
"""fixtures command is not available when not configured."""
|
||||
# Create pyproject.toml without fixtures
|
||||
pyproject = tmp_path / "pyproject.toml"
|
||||
pyproject.write_text('[project]\nname = "test"\n')
|
||||
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
# Reload the CLI module
|
||||
import importlib
|
||||
|
||||
from fastapi_toolsets.cli import app
|
||||
|
||||
importlib.reload(app)
|
||||
|
||||
result = runner.invoke(app.cli, ["--help"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "fixtures" not in result.output
|
||||
|
||||
|
||||
class TestCustomCliConfig:
|
||||
"""Tests for custom CLI configuration."""
|
||||
|
||||
def test_cli_with_custom_cli(self, tmp_path, monkeypatch):
|
||||
"""CLI uses custom Typer instance when configured."""
|
||||
import typer
|
||||
|
||||
# Create pyproject.toml with custom_cli config
|
||||
pyproject = tmp_path / "pyproject.toml"
|
||||
pyproject.write_text('[tool.fastapi-toolsets]\ncustom_cli = "my_cli:cli"\n')
|
||||
|
||||
# Create custom CLI module with its own Typer and commands
|
||||
cli_file = tmp_path / "my_cli.py"
|
||||
cli_file.write_text(
|
||||
"import typer\n"
|
||||
"\n"
|
||||
"cli = typer.Typer(name='my-app', help='My custom CLI')\n"
|
||||
"\n"
|
||||
"@cli.command()\n"
|
||||
"def hello():\n"
|
||||
' print("Hello from custom CLI!")\n'
|
||||
)
|
||||
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
# Add tmp_path to sys.path for imports
|
||||
if str(tmp_path) not in sys.path:
|
||||
sys.path.insert(0, str(tmp_path))
|
||||
|
||||
# Remove my_cli from sys.modules if it was previously loaded
|
||||
if "my_cli" in sys.modules:
|
||||
del sys.modules["my_cli"]
|
||||
|
||||
# Reload the CLI module to pick up new config
|
||||
import importlib
|
||||
|
||||
from fastapi_toolsets.cli import app
|
||||
|
||||
importlib.reload(app)
|
||||
|
||||
try:
|
||||
# Verify custom CLI is used
|
||||
assert isinstance(app.cli, typer.Typer)
|
||||
|
||||
result = runner.invoke(app.cli, ["--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "My custom CLI" in result.output
|
||||
assert "hello" in result.output
|
||||
|
||||
result = runner.invoke(app.cli, ["hello"])
|
||||
assert result.exit_code == 0
|
||||
assert "Hello from custom CLI!" in result.output
|
||||
finally:
|
||||
if str(tmp_path) in sys.path:
|
||||
sys.path.remove(str(tmp_path))
|
||||
if "my_cli" in sys.modules:
|
||||
del sys.modules["my_cli"]
|
||||
|
||||
def test_custom_cli_with_fixtures(self, tmp_path, monkeypatch):
|
||||
"""Custom CLI gets fixtures command added when configured."""
|
||||
# Create pyproject.toml with both custom_cli and fixtures
|
||||
pyproject = tmp_path / "pyproject.toml"
|
||||
pyproject.write_text(
|
||||
"[tool.fastapi-toolsets]\n"
|
||||
'custom_cli = "my_cli:cli"\n'
|
||||
'fixtures = "fixtures:registry"\n'
|
||||
'db_context = "db:get_session"\n'
|
||||
)
|
||||
|
||||
# Create custom CLI module
|
||||
cli_file = tmp_path / "my_cli.py"
|
||||
cli_file.write_text(
|
||||
"import typer\n"
|
||||
"\n"
|
||||
"cli = typer.Typer(name='my-app', help='My custom CLI')\n"
|
||||
"\n"
|
||||
"@cli.command()\n"
|
||||
"def hello():\n"
|
||||
' print("Hello!")\n'
|
||||
)
|
||||
|
||||
# Create fixtures module
|
||||
fixtures_file = tmp_path / "fixtures.py"
|
||||
fixtures_file.write_text(
|
||||
"from fastapi_toolsets.fixtures import FixtureRegistry\n"
|
||||
"\n"
|
||||
"registry = FixtureRegistry()\n"
|
||||
)
|
||||
|
||||
# Create db module
|
||||
db_file = tmp_path / "db.py"
|
||||
db_file.write_text(
|
||||
"from contextlib import asynccontextmanager\n"
|
||||
"\n"
|
||||
"@asynccontextmanager\n"
|
||||
"async def get_session():\n"
|
||||
" yield None\n"
|
||||
)
|
||||
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
if str(tmp_path) not in sys.path:
|
||||
sys.path.insert(0, str(tmp_path))
|
||||
|
||||
for mod in ["my_cli", "fixtures", "db"]:
|
||||
if mod in sys.modules:
|
||||
del sys.modules[mod]
|
||||
|
||||
import importlib
|
||||
|
||||
from fastapi_toolsets.cli import app
|
||||
|
||||
importlib.reload(app)
|
||||
|
||||
try:
|
||||
result = runner.invoke(app.cli, ["--help"])
|
||||
assert result.exit_code == 0
|
||||
# Should have both custom command and fixtures
|
||||
assert "hello" in result.output
|
||||
assert "fixtures" in result.output
|
||||
finally:
|
||||
if str(tmp_path) in sys.path:
|
||||
sys.path.remove(str(tmp_path))
|
||||
for mod in ["my_cli", "fixtures", "db"]:
|
||||
if mod in sys.modules:
|
||||
del sys.modules[mod]
|
||||
|
||||
|
||||
class TestAsyncCommand:
|
||||
"""Tests for async_command decorator."""
|
||||
|
||||
def test_async_command_runs_coroutine(self):
|
||||
"""async_command runs async function synchronously."""
|
||||
|
||||
@async_command
|
||||
async def async_func(value: int) -> int:
|
||||
return value * 2
|
||||
|
||||
result = async_func(21)
|
||||
assert result == 42
|
||||
|
||||
def test_async_command_preserves_signature(self):
|
||||
"""async_command preserves function signature."""
|
||||
|
||||
@async_command
|
||||
async def async_func(name: str, count: int = 1) -> str:
|
||||
return f"{name} x {count}"
|
||||
|
||||
result = async_func("test", count=3)
|
||||
assert result == "test x 3"
|
||||
|
||||
def test_async_command_preserves_docstring(self):
|
||||
"""async_command preserves function docstring."""
|
||||
|
||||
@async_command
|
||||
async def async_func() -> None:
|
||||
"""This is a docstring."""
|
||||
pass
|
||||
|
||||
assert async_func.__doc__ == """This is a docstring."""
|
||||
|
||||
def test_async_command_preserves_name(self):
|
||||
"""async_command preserves function name."""
|
||||
|
||||
@async_command
|
||||
async def my_async_function() -> None:
|
||||
pass
|
||||
|
||||
assert my_async_function.__name__ == "my_async_function"
|
||||
2159
tests/test_crud.py
2159
tests/test_crud.py
File diff suppressed because it is too large
Load Diff
1439
tests/test_crud_search.py
Normal file
1439
tests/test_crud_search.py
Normal file
File diff suppressed because it is too large
Load Diff
191
tests/test_db.py
191
tests/test_db.py
@@ -1,17 +1,28 @@
|
||||
"""Tests for fastapi_toolsets.db module."""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from fastapi_toolsets.db import (
|
||||
LockMode,
|
||||
cleanup_tables,
|
||||
create_database,
|
||||
create_db_context,
|
||||
create_db_dependency,
|
||||
get_transaction,
|
||||
lock_tables,
|
||||
wait_for_row_change,
|
||||
)
|
||||
from fastapi_toolsets.exceptions import NotFoundError
|
||||
from fastapi_toolsets.pytest import create_db_session
|
||||
|
||||
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User
|
||||
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
|
||||
|
||||
|
||||
class TestCreateDbDependency:
|
||||
@@ -241,3 +252,181 @@ class TestLockTables:
|
||||
|
||||
result = await RoleCrud.first(db_session, [Role.name == "lock_rollback_role"])
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestWaitForRowChange:
|
||||
"""Tests for wait_for_row_change polling function."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_detects_update(self, db_session: AsyncSession, engine):
|
||||
"""Returns updated instance when a column value changes."""
|
||||
role = Role(name="watch_role")
|
||||
db_session.add(role)
|
||||
await db_session.commit()
|
||||
|
||||
async def update_later():
|
||||
await asyncio.sleep(0.15)
|
||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
async with factory() as other:
|
||||
r = await other.get(Role, role.id)
|
||||
assert r is not None
|
||||
r.name = "updated_role"
|
||||
await other.commit()
|
||||
|
||||
update_task = asyncio.create_task(update_later())
|
||||
result = await wait_for_row_change(db_session, Role, role.id, interval=0.05)
|
||||
await update_task
|
||||
|
||||
assert result.name == "updated_role"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_watches_specific_columns(self, db_session: AsyncSession, engine):
|
||||
"""Only triggers on changes to specified columns."""
|
||||
user = User(username="testuser", email="test@example.com")
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
async def update_later():
|
||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
# First: change email (not watched) — should not trigger
|
||||
await asyncio.sleep(0.15)
|
||||
async with factory() as other:
|
||||
u = await other.get(User, user.id)
|
||||
assert u is not None
|
||||
u.email = "new@example.com"
|
||||
await other.commit()
|
||||
# Second: change username (watched) — should trigger
|
||||
await asyncio.sleep(0.15)
|
||||
async with factory() as other:
|
||||
u = await other.get(User, user.id)
|
||||
assert u is not None
|
||||
u.username = "newuser"
|
||||
await other.commit()
|
||||
|
||||
update_task = asyncio.create_task(update_later())
|
||||
result = await wait_for_row_change(
|
||||
db_session, User, user.id, columns=["username"], interval=0.05
|
||||
)
|
||||
await update_task
|
||||
|
||||
assert result.username == "newuser"
|
||||
assert result.email == "new@example.com"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_nonexistent_row_raises(self, db_session: AsyncSession):
|
||||
"""Raises NotFoundError when the row does not exist."""
|
||||
fake_id = uuid.uuid4()
|
||||
with pytest.raises(NotFoundError, match="not found"):
|
||||
await wait_for_row_change(db_session, Role, fake_id, interval=0.05)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_timeout_raises(self, db_session: AsyncSession):
|
||||
"""Raises TimeoutError when no change is detected within timeout."""
|
||||
role = Role(name="timeout_role")
|
||||
db_session.add(role)
|
||||
await db_session.commit()
|
||||
|
||||
with pytest.raises(TimeoutError):
|
||||
await wait_for_row_change(
|
||||
db_session, Role, role.id, interval=0.05, timeout=0.2
|
||||
)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_deleted_row_raises(self, db_session: AsyncSession, engine):
|
||||
"""Raises NotFoundError when the row is deleted during polling."""
|
||||
role = Role(name="delete_role")
|
||||
db_session.add(role)
|
||||
await db_session.commit()
|
||||
|
||||
async def delete_later():
|
||||
await asyncio.sleep(0.15)
|
||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
async with factory() as other:
|
||||
r = await other.get(Role, role.id)
|
||||
await other.delete(r)
|
||||
await other.commit()
|
||||
|
||||
delete_task = asyncio.create_task(delete_later())
|
||||
with pytest.raises(NotFoundError):
|
||||
await wait_for_row_change(db_session, Role, role.id, interval=0.05)
|
||||
await delete_task
|
||||
|
||||
|
||||
class TestCreateDatabase:
|
||||
"""Tests for create_database."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_creates_database(self):
|
||||
"""Database is created by create_database."""
|
||||
target_url = (
|
||||
make_url(DATABASE_URL)
|
||||
.set(database="test_create_db_general")
|
||||
.render_as_string(hide_password=False)
|
||||
)
|
||||
expected_db: str = make_url(target_url).database # type: ignore[assignment]
|
||||
|
||||
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||
try:
|
||||
async with engine.connect() as conn:
|
||||
await conn.execute(text(f"DROP DATABASE IF EXISTS {expected_db}"))
|
||||
|
||||
await create_database(db_name=expected_db, server_url=DATABASE_URL)
|
||||
|
||||
async with engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||
{"name": expected_db},
|
||||
)
|
||||
assert result.scalar() == 1
|
||||
|
||||
# Cleanup
|
||||
async with engine.connect() as conn:
|
||||
await conn.execute(text(f"DROP DATABASE IF EXISTS {expected_db}"))
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
class TestCleanupTables:
|
||||
"""Tests for cleanup_tables helper."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_truncates_all_tables(self):
|
||||
"""All table rows are removed after cleanup_tables."""
|
||||
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||
role = Role(id=uuid.uuid4(), name="cleanup_role")
|
||||
session.add(role)
|
||||
await session.flush()
|
||||
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
username="cleanup_user",
|
||||
email="cleanup@test.com",
|
||||
role_id=role.id,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
||||
# Verify rows exist
|
||||
roles_count = await RoleCrud.count(session)
|
||||
users_count = await UserCrud.count(session)
|
||||
assert roles_count == 1
|
||||
assert users_count == 1
|
||||
|
||||
await cleanup_tables(session, Base)
|
||||
|
||||
# Verify tables are empty
|
||||
roles_count = await RoleCrud.count(session)
|
||||
users_count = await UserCrud.count(session)
|
||||
assert roles_count == 0
|
||||
assert users_count == 0
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_noop_for_empty_metadata(self):
|
||||
"""cleanup_tables does not raise when metadata has no tables."""
|
||||
|
||||
class EmptyBase(DeclarativeBase):
|
||||
pass
|
||||
|
||||
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||
# Should not raise
|
||||
await cleanup_tables(session, EmptyBase)
|
||||
|
||||
277
tests/test_dependencies.py
Normal file
277
tests/test_dependencies.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""Tests for fastapi_toolsets.dependencies module."""
|
||||
|
||||
import inspect
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated, Any, cast
|
||||
|
||||
import pytest
|
||||
from fastapi.params import Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from fastapi_toolsets.dependencies import (
|
||||
BodyDependency,
|
||||
PathDependency,
|
||||
_unwrap_session_dep,
|
||||
)
|
||||
|
||||
from .conftest import Role, RoleCreate, RoleCrud, User
|
||||
|
||||
|
||||
async def mock_get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Mock session dependency for testing."""
|
||||
yield None
|
||||
|
||||
|
||||
MockSessionDep = Annotated[AsyncSession, Depends(mock_get_db)]
|
||||
|
||||
|
||||
class TestUnwrapSessionDep:
|
||||
def test_plain_callable_returned_as_is(self):
|
||||
"""Plain callable is returned unchanged."""
|
||||
assert _unwrap_session_dep(mock_get_db) is mock_get_db
|
||||
|
||||
def test_annotated_with_depends_unwrapped(self):
|
||||
"""Annotated form with Depends is unwrapped to the plain callable."""
|
||||
assert _unwrap_session_dep(MockSessionDep) is mock_get_db
|
||||
|
||||
def test_annotated_without_depends_returned_as_is(self):
|
||||
"""Annotated form with no Depends falls back to returning session_dep as-is."""
|
||||
annotated_no_dep = Annotated[AsyncSession, "not_a_depends"]
|
||||
assert _unwrap_session_dep(annotated_no_dep) is annotated_no_dep
|
||||
|
||||
|
||||
class TestPathDependency:
|
||||
"""Tests for PathDependency factory."""
|
||||
|
||||
def test_returns_depends_instance(self):
|
||||
"""PathDependency returns a Depends instance."""
|
||||
dep = PathDependency(Role, Role.id, session_dep=mock_get_db)
|
||||
assert isinstance(dep, Depends)
|
||||
|
||||
def test_signature_has_default_param_name(self):
|
||||
"""PathDependency uses model_field as default param name."""
|
||||
dep = cast(Any, PathDependency(Role, Role.id, session_dep=mock_get_db))
|
||||
func = dep.dependency
|
||||
|
||||
sig = inspect.signature(func)
|
||||
params = list(sig.parameters.keys())
|
||||
|
||||
assert "role_id" in params
|
||||
assert "session" in params
|
||||
|
||||
def test_signature_has_correct_type_annotation(self):
|
||||
"""PathDependency uses field's python type for annotation."""
|
||||
dep = cast(Any, PathDependency(Role, Role.id, session_dep=mock_get_db))
|
||||
func = dep.dependency
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
||||
assert sig.parameters["role_id"].annotation == uuid.UUID
|
||||
assert sig.parameters["session"].annotation == AsyncSession
|
||||
|
||||
def test_signature_session_has_depends_default(self):
|
||||
"""PathDependency session param has Depends as default."""
|
||||
dep = cast(Any, PathDependency(Role, Role.id, session_dep=mock_get_db))
|
||||
func = dep.dependency
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
||||
assert isinstance(sig.parameters["session"].default, Depends)
|
||||
|
||||
def test_custom_param_name_in_signature(self):
|
||||
"""PathDependency uses custom param_name in signature."""
|
||||
dep = cast(
|
||||
Any,
|
||||
PathDependency(
|
||||
Role, Role.id, session_dep=mock_get_db, param_name="role_uuid"
|
||||
),
|
||||
)
|
||||
func = dep.dependency
|
||||
|
||||
sig = inspect.signature(func)
|
||||
params = list(sig.parameters.keys())
|
||||
|
||||
assert "role_uuid" in params
|
||||
assert "id" not in params
|
||||
|
||||
def test_string_field_type(self):
|
||||
"""PathDependency handles string field types."""
|
||||
dep = cast(Any, PathDependency(User, User.username, session_dep=mock_get_db))
|
||||
func = dep.dependency
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
||||
assert sig.parameters["user_username"].annotation is str
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_dependency_fetches_object(self, db_session):
|
||||
"""PathDependency inner function fetches object from database."""
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="test_role"))
|
||||
|
||||
dep = cast(Any, PathDependency(Role, Role.id, session_dep=mock_get_db))
|
||||
func = dep.dependency
|
||||
|
||||
result = await func(session=db_session, role_id=role.id)
|
||||
|
||||
assert result.id == role.id
|
||||
assert result.name == "test_role"
|
||||
|
||||
def test_annotated_session_dep_returns_depends_instance(self):
|
||||
"""PathDependency accepts Annotated[AsyncSession, Depends(...)] form."""
|
||||
dep = PathDependency(Role, Role.id, session_dep=MockSessionDep)
|
||||
assert isinstance(dep, Depends)
|
||||
|
||||
def test_annotated_session_dep_signature(self):
|
||||
"""PathDependency with Annotated session_dep produces a valid signature."""
|
||||
dep = cast(Any, PathDependency(Role, Role.id, session_dep=MockSessionDep))
|
||||
sig = inspect.signature(dep.dependency)
|
||||
|
||||
assert "role_id" in sig.parameters
|
||||
assert "session" in sig.parameters
|
||||
assert isinstance(sig.parameters["session"].default, Depends)
|
||||
|
||||
def test_annotated_session_dep_unwraps_callable(self):
|
||||
"""PathDependency with Annotated form uses the underlying callable, not the Annotated type."""
|
||||
dep = cast(Any, PathDependency(Role, Role.id, session_dep=MockSessionDep))
|
||||
sig = inspect.signature(dep.dependency)
|
||||
|
||||
inner_dep = sig.parameters["session"].default
|
||||
assert inner_dep.dependency is mock_get_db
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_annotated_session_dep_fetches_object(self, db_session):
|
||||
"""PathDependency with Annotated session_dep correctly fetches object from database."""
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="annotated_role"))
|
||||
|
||||
dep = cast(Any, PathDependency(Role, Role.id, session_dep=MockSessionDep))
|
||||
result = await dep.dependency(session=db_session, role_id=role.id)
|
||||
|
||||
assert result.id == role.id
|
||||
assert result.name == "annotated_role"
|
||||
|
||||
|
||||
class TestBodyDependency:
|
||||
"""Tests for BodyDependency factory."""
|
||||
|
||||
def test_returns_depends_instance(self):
|
||||
"""BodyDependency returns a Depends instance."""
|
||||
dep = BodyDependency(
|
||||
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||
)
|
||||
assert isinstance(dep, Depends)
|
||||
|
||||
def test_signature_has_body_field_as_param(self):
|
||||
"""BodyDependency uses body_field as param name."""
|
||||
dep = cast(
|
||||
Any,
|
||||
BodyDependency(
|
||||
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||
),
|
||||
)
|
||||
func = dep.dependency
|
||||
|
||||
sig = inspect.signature(func)
|
||||
params = list(sig.parameters.keys())
|
||||
|
||||
assert "role_id" in params
|
||||
assert "session" in params
|
||||
|
||||
def test_signature_has_correct_type_annotation(self):
|
||||
"""BodyDependency uses field's python type for annotation."""
|
||||
dep = cast(
|
||||
Any,
|
||||
BodyDependency(
|
||||
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||
),
|
||||
)
|
||||
func = dep.dependency
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
||||
assert sig.parameters["role_id"].annotation == uuid.UUID
|
||||
assert sig.parameters["session"].annotation == AsyncSession
|
||||
|
||||
def test_signature_session_has_depends_default(self):
|
||||
"""BodyDependency session param has Depends as default."""
|
||||
dep = cast(
|
||||
Any,
|
||||
BodyDependency(
|
||||
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||
),
|
||||
)
|
||||
func = dep.dependency
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
||||
assert isinstance(sig.parameters["session"].default, Depends)
|
||||
|
||||
def test_different_body_field_name(self):
|
||||
"""BodyDependency can use any body_field name."""
|
||||
dep = cast(
|
||||
Any,
|
||||
BodyDependency(
|
||||
User, User.id, session_dep=mock_get_db, body_field="user_uuid"
|
||||
),
|
||||
)
|
||||
func = dep.dependency
|
||||
|
||||
sig = inspect.signature(func)
|
||||
params = list(sig.parameters.keys())
|
||||
|
||||
assert "user_uuid" in params
|
||||
assert "id" not in params
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_dependency_fetches_object(self, db_session):
|
||||
"""BodyDependency inner function fetches object from database."""
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="body_test_role"))
|
||||
|
||||
dep = cast(
|
||||
Any,
|
||||
BodyDependency(
|
||||
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||
),
|
||||
)
|
||||
func = dep.dependency
|
||||
|
||||
result = await func(session=db_session, role_id=role.id)
|
||||
|
||||
assert result.id == role.id
|
||||
assert result.name == "body_test_role"
|
||||
|
||||
def test_annotated_session_dep_returns_depends_instance(self):
|
||||
"""BodyDependency accepts Annotated[AsyncSession, Depends(...)] form."""
|
||||
dep = BodyDependency(
|
||||
Role, Role.id, session_dep=MockSessionDep, body_field="role_id"
|
||||
)
|
||||
assert isinstance(dep, Depends)
|
||||
|
||||
def test_annotated_session_dep_unwraps_callable(self):
|
||||
"""BodyDependency with Annotated form uses the underlying callable, not the Annotated type."""
|
||||
dep = cast(
|
||||
Any,
|
||||
BodyDependency(
|
||||
Role, Role.id, session_dep=MockSessionDep, body_field="role_id"
|
||||
),
|
||||
)
|
||||
sig = inspect.signature(dep.dependency)
|
||||
|
||||
inner_dep = sig.parameters["session"].default
|
||||
assert inner_dep.dependency is mock_get_db
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_annotated_session_dep_fetches_object(self, db_session):
|
||||
"""BodyDependency with Annotated session_dep correctly fetches object from database."""
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="body_annotated_role"))
|
||||
|
||||
dep = cast(
|
||||
Any,
|
||||
BodyDependency(
|
||||
Role, Role.id, session_dep=MockSessionDep, body_field="role_id"
|
||||
),
|
||||
)
|
||||
result = await dep.dependency(session=db_session, role_id=role.id)
|
||||
|
||||
assert result.id == role.id
|
||||
assert result.name == "body_annotated_role"
|
||||
486
tests/test_example_pagination_search.py
Normal file
486
tests/test_example_pagination_search.py
Normal file
@@ -0,0 +1,486 @@
|
||||
"""Live test for the docs/examples/pagination-search.md example.
|
||||
|
||||
Spins up the exact FastAPI app described in the example (sourced from
|
||||
docs_src/examples/pagination_search/) and exercises it through a real HTTP
|
||||
client against a real PostgreSQL database.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from docs_src.examples.pagination_search.db import get_db
|
||||
from docs_src.examples.pagination_search.models import Article, Base, Category
|
||||
from docs_src.examples.pagination_search.routes import router
|
||||
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||
from fastapi_toolsets.pytest import create_db_session
|
||||
|
||||
from .conftest import DATABASE_URL
|
||||
|
||||
|
||||
def build_app(session: AsyncSession) -> FastAPI:
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app)
|
||||
|
||||
async def override_get_db():
|
||||
yield session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.include_router(router)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def ex_db_session():
|
||||
"""Isolated session for the example models (separate tables from conftest)."""
|
||||
async with create_db_session(DATABASE_URL, Base) as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client(ex_db_session: AsyncSession):
|
||||
app = build_app(ex_db_session)
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as ac:
|
||||
yield ac
|
||||
|
||||
|
||||
async def seed(session: AsyncSession):
|
||||
"""Insert representative fixture data."""
|
||||
python = Category(name="python")
|
||||
backend = Category(name="backend")
|
||||
session.add_all([python, backend])
|
||||
await session.flush()
|
||||
|
||||
now = datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc)
|
||||
session.add_all(
|
||||
[
|
||||
Article(
|
||||
title="FastAPI tips",
|
||||
body="Ten useful tips for FastAPI.",
|
||||
status="published",
|
||||
published=True,
|
||||
category_id=python.id,
|
||||
created_at=now,
|
||||
),
|
||||
Article(
|
||||
title="SQLAlchemy async",
|
||||
body="How to use async SQLAlchemy.",
|
||||
status="published",
|
||||
published=True,
|
||||
category_id=backend.id,
|
||||
created_at=now + datetime.timedelta(seconds=1),
|
||||
),
|
||||
Article(
|
||||
title="Draft notes",
|
||||
body="Work in progress.",
|
||||
status="draft",
|
||||
published=False,
|
||||
category_id=None,
|
||||
created_at=now + datetime.timedelta(seconds=2),
|
||||
),
|
||||
]
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
|
||||
class TestAppSessionDep:
|
||||
@pytest.mark.anyio
|
||||
async def test_get_db_yields_async_session(self):
|
||||
"""get_db yields a real AsyncSession when called directly."""
|
||||
from docs_src.examples.pagination_search.db import get_db
|
||||
|
||||
gen = get_db()
|
||||
session = await gen.__anext__()
|
||||
assert isinstance(session, AsyncSession)
|
||||
await session.close()
|
||||
|
||||
|
||||
class TestOffsetPagination:
|
||||
@pytest.mark.anyio
|
||||
async def test_returns_all_articles(self, client: AsyncClient, ex_db_session):
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/offset")
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["pagination"]["total_count"] == 3
|
||||
assert len(body["data"]) == 3
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_pagination_page_size(self, client: AsyncClient, ex_db_session):
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/offset?items_per_page=2&page=1")
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert len(body["data"]) == 2
|
||||
assert body["pagination"]["total_count"] == 3
|
||||
assert body["pagination"]["has_more"] is True
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_full_text_search(self, client: AsyncClient, ex_db_session):
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/offset?search=fastapi")
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["pagination"]["total_count"] == 1
|
||||
assert body["data"][0]["title"] == "FastAPI tips"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_traverses_relationship(
|
||||
self, client: AsyncClient, ex_db_session
|
||||
):
|
||||
await seed(ex_db_session)
|
||||
|
||||
# "python" matches Category.name, not Article.title or body
|
||||
resp = await client.get("/articles/offset?search=python")
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["pagination"]["total_count"] == 1
|
||||
assert body["data"][0]["title"] == "FastAPI tips"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_facet_filter_scalar(self, client: AsyncClient, ex_db_session):
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/offset?status=published")
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["pagination"]["total_count"] == 2
|
||||
assert all(a["status"] == "published" for a in body["data"])
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_facet_filter_multi_value(self, client: AsyncClient, ex_db_session):
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/offset?status=published&status=draft")
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["pagination"]["total_count"] == 3
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_filter_attributes_in_response(
|
||||
self, client: AsyncClient, ex_db_session
|
||||
):
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/offset")
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
fa = body["filter_attributes"]
|
||||
assert set(fa["status"]) == {"draft", "published"}
|
||||
# "name" is unique across all facet fields — no prefix needed
|
||||
assert set(fa["name"]) == {"backend", "python"}
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_filter_attributes_scoped_to_filter(
|
||||
self, client: AsyncClient, ex_db_session
|
||||
):
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/offset?status=published")
|
||||
|
||||
body = resp.json()
|
||||
# draft is filtered out → should not appear in filter_attributes
|
||||
assert "draft" not in body["filter_attributes"]["status"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_and_filter_combined(self, client: AsyncClient, ex_db_session):
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/offset?search=async&status=published")
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["pagination"]["total_count"] == 1
|
||||
assert body["data"][0]["title"] == "SQLAlchemy async"
|
||||
|
||||
|
||||
class TestCursorPagination:
|
||||
@pytest.mark.anyio
|
||||
async def test_first_page(self, client: AsyncClient, ex_db_session):
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/cursor?items_per_page=2")
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert len(body["data"]) == 2
|
||||
assert body["pagination"]["has_more"] is True
|
||||
assert body["pagination"]["next_cursor"] is not None
|
||||
assert body["pagination"]["prev_cursor"] is None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_second_page(self, client: AsyncClient, ex_db_session):
|
||||
await seed(ex_db_session)
|
||||
|
||||
first = await client.get("/articles/cursor?items_per_page=2")
|
||||
next_cursor = first.json()["pagination"]["next_cursor"]
|
||||
|
||||
resp = await client.get(
|
||||
f"/articles/cursor?items_per_page=2&cursor={next_cursor}"
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert len(body["data"]) == 1
|
||||
assert body["pagination"]["has_more"] is False
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_facet_filter(self, client: AsyncClient, ex_db_session):
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/cursor?status=draft")
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert len(body["data"]) == 1
|
||||
assert body["data"][0]["status"] == "draft"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_full_text_search(self, client: AsyncClient, ex_db_session):
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/cursor?search=sqlalchemy")
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert len(body["data"]) == 1
|
||||
assert body["data"][0]["title"] == "SQLAlchemy async"
|
||||
|
||||
|
||||
class TestOffsetSorting:
|
||||
"""Tests for order_by / order query parameters on the offset endpoint."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_default_order_uses_created_at_asc(
|
||||
self, client: AsyncClient, ex_db_session
|
||||
):
|
||||
"""No order_by → default field (created_at) ASC."""
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/offset")
|
||||
|
||||
assert resp.status_code == 200
|
||||
titles = [a["title"] for a in resp.json()["data"]]
|
||||
assert titles == ["FastAPI tips", "SQLAlchemy async", "Draft notes"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_order_by_title_asc(self, client: AsyncClient, ex_db_session):
|
||||
"""order_by=title&order=asc returns alphabetical order."""
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/offset?order_by=title&order=asc")
|
||||
|
||||
assert resp.status_code == 200
|
||||
titles = [a["title"] for a in resp.json()["data"]]
|
||||
assert titles == ["Draft notes", "FastAPI tips", "SQLAlchemy async"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_order_by_title_desc(self, client: AsyncClient, ex_db_session):
|
||||
"""order_by=title&order=desc returns reverse alphabetical order."""
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/offset?order_by=title&order=desc")
|
||||
|
||||
assert resp.status_code == 200
|
||||
titles = [a["title"] for a in resp.json()["data"]]
|
||||
assert titles == ["SQLAlchemy async", "FastAPI tips", "Draft notes"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_order_by_created_at_desc(self, client: AsyncClient, ex_db_session):
|
||||
"""order_by=created_at&order=desc returns newest-first."""
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/offset?order_by=created_at&order=desc")
|
||||
|
||||
assert resp.status_code == 200
|
||||
titles = [a["title"] for a in resp.json()["data"]]
|
||||
assert titles == ["Draft notes", "SQLAlchemy async", "FastAPI tips"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_invalid_order_by_returns_422(
|
||||
self, client: AsyncClient, ex_db_session
|
||||
):
|
||||
"""Unknown order_by field returns 422 with SORT-422 error code."""
|
||||
resp = await client.get("/articles/offset?order_by=nonexistent_field")
|
||||
|
||||
assert resp.status_code == 422
|
||||
body = resp.json()
|
||||
assert body["error_code"] == "SORT-422"
|
||||
assert body["status"] == "FAIL"
|
||||
|
||||
|
||||
class TestCursorSorting:
|
||||
"""Tests for order_by / order query parameters on the cursor endpoint.
|
||||
|
||||
In cursor_paginate the cursor_column is always the primary sort; order_by
|
||||
acts as a secondary tiebreaker. With the seeded articles (all having unique
|
||||
created_at values) the overall ordering is always created_at ASC regardless
|
||||
of the order_by value — only the valid/invalid field check and the response
|
||||
shape are meaningful here.
|
||||
"""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_default_order_uses_created_at_asc(
|
||||
self, client: AsyncClient, ex_db_session
|
||||
):
|
||||
"""No order_by → default field (created_at) ASC."""
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/cursor")
|
||||
|
||||
assert resp.status_code == 200
|
||||
titles = [a["title"] for a in resp.json()["data"]]
|
||||
assert titles == ["FastAPI tips", "SQLAlchemy async", "Draft notes"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_order_by_title_asc_accepted(
|
||||
self, client: AsyncClient, ex_db_session
|
||||
):
|
||||
"""order_by=title is a valid field — request succeeds and returns all articles."""
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/cursor?order_by=title&order=asc")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()["data"]) == 3
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_order_by_title_desc_accepted(
|
||||
self, client: AsyncClient, ex_db_session
|
||||
):
|
||||
"""order_by=title&order=desc is valid — request succeeds and returns all articles."""
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/cursor?order_by=title&order=desc")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()["data"]) == 3
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_invalid_order_by_returns_422(
|
||||
self, client: AsyncClient, ex_db_session
|
||||
):
|
||||
"""Unknown order_by field returns 422 with SORT-422 error code."""
|
||||
resp = await client.get("/articles/cursor?order_by=nonexistent_field")
|
||||
|
||||
assert resp.status_code == 422
|
||||
body = resp.json()
|
||||
assert body["error_code"] == "SORT-422"
|
||||
assert body["status"] == "FAIL"
|
||||
|
||||
|
||||
class TestPaginateUnified:
|
||||
"""Tests for the unified GET /articles/ endpoint using paginate()."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_defaults_to_offset_pagination(
|
||||
self, client: AsyncClient, ex_db_session
|
||||
):
|
||||
"""Without pagination_type, defaults to offset pagination."""
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/")
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["pagination_type"] == "offset"
|
||||
assert "total_count" in body["pagination"]
|
||||
assert body["pagination"]["total_count"] == 3
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_explicit_offset_pagination(self, client: AsyncClient, ex_db_session):
|
||||
"""pagination_type=offset returns OffsetPagination metadata."""
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get(
|
||||
"/articles/?pagination_type=offset&page=1&items_per_page=2"
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["pagination_type"] == "offset"
|
||||
assert body["pagination"]["total_count"] == 3
|
||||
assert body["pagination"]["page"] == 1
|
||||
assert body["pagination"]["has_more"] is True
|
||||
assert len(body["data"]) == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cursor_pagination_type(self, client: AsyncClient, ex_db_session):
|
||||
"""pagination_type=cursor returns CursorPagination metadata."""
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/?pagination_type=cursor&items_per_page=2")
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["pagination_type"] == "cursor"
|
||||
assert "next_cursor" in body["pagination"]
|
||||
assert "total_count" not in body["pagination"]
|
||||
assert body["pagination"]["has_more"] is True
|
||||
assert len(body["data"]) == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cursor_pagination_navigate_pages(
|
||||
self, client: AsyncClient, ex_db_session
|
||||
):
|
||||
"""Cursor from first page can be used to fetch the next page."""
|
||||
await seed(ex_db_session)
|
||||
|
||||
first = await client.get("/articles/?pagination_type=cursor&items_per_page=2")
|
||||
assert first.status_code == 200
|
||||
first_body = first.json()
|
||||
next_cursor = first_body["pagination"]["next_cursor"]
|
||||
assert next_cursor is not None
|
||||
|
||||
second = await client.get(
|
||||
f"/articles/?pagination_type=cursor&items_per_page=2&cursor={next_cursor}"
|
||||
)
|
||||
assert second.status_code == 200
|
||||
second_body = second.json()
|
||||
assert second_body["pagination_type"] == "cursor"
|
||||
assert second_body["pagination"]["has_more"] is False
|
||||
assert len(second_body["data"]) == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cursor_pagination_with_search(
|
||||
self, client: AsyncClient, ex_db_session
|
||||
):
|
||||
"""paginate() with cursor type respects search parameter."""
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/?pagination_type=cursor&search=fastapi")
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["pagination_type"] == "cursor"
|
||||
assert len(body["data"]) == 1
|
||||
assert body["data"][0]["title"] == "FastAPI tips"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_offset_pagination_with_filter(
|
||||
self, client: AsyncClient, ex_db_session
|
||||
):
|
||||
"""paginate() with offset type respects filter_by parameter."""
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/?pagination_type=offset&status=published")
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["pagination_type"] == "offset"
|
||||
assert body["pagination"]["total_count"] == 2
|
||||
@@ -2,12 +2,14 @@
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from fastapi_toolsets.exceptions import (
|
||||
ApiException,
|
||||
ConflictError,
|
||||
ForbiddenError,
|
||||
InvalidOrderFieldError,
|
||||
NotFoundError,
|
||||
UnauthorizedError,
|
||||
generate_error_responses,
|
||||
@@ -35,8 +37,8 @@ class TestApiException:
|
||||
assert error.api_error.msg == "I'm a teapot"
|
||||
assert str(error) == "I'm a teapot"
|
||||
|
||||
def test_custom_detail_message(self):
|
||||
"""Custom detail overrides default message."""
|
||||
def test_detail_overrides_msg_and_str(self):
|
||||
"""detail sets both str(exc) and api_error.msg; class-level msg is unchanged."""
|
||||
|
||||
class CustomError(ApiException):
|
||||
api_error = ApiError(
|
||||
@@ -46,8 +48,172 @@ class TestApiException:
|
||||
err_code="BAD-400",
|
||||
)
|
||||
|
||||
error = CustomError("Custom message")
|
||||
assert str(error) == "Custom message"
|
||||
error = CustomError("Widget not found")
|
||||
assert str(error) == "Widget not found"
|
||||
assert error.api_error.msg == "Widget not found"
|
||||
assert CustomError.api_error.msg == "Bad Request" # class unchanged
|
||||
|
||||
def test_desc_override(self):
|
||||
"""desc kwarg overrides api_error.desc on the instance only."""
|
||||
|
||||
class MyError(ApiException):
|
||||
api_error = ApiError(
|
||||
code=400, msg="Error", desc="Default.", err_code="ERR-400"
|
||||
)
|
||||
|
||||
err = MyError(desc="Custom desc.")
|
||||
assert err.api_error.desc == "Custom desc."
|
||||
assert MyError.api_error.desc == "Default." # class unchanged
|
||||
|
||||
def test_data_override(self):
|
||||
"""data kwarg sets api_error.data on the instance only."""
|
||||
|
||||
class MyError(ApiException):
|
||||
api_error = ApiError(
|
||||
code=400, msg="Error", desc="Default.", err_code="ERR-400"
|
||||
)
|
||||
|
||||
err = MyError(data={"key": "value"})
|
||||
assert err.api_error.data == {"key": "value"}
|
||||
assert MyError.api_error.data is None # class unchanged
|
||||
|
||||
def test_desc_and_data_override(self):
|
||||
"""detail, desc and data can all be overridden together."""
|
||||
|
||||
class MyError(ApiException):
|
||||
api_error = ApiError(
|
||||
code=400, msg="Error", desc="Default.", err_code="ERR-400"
|
||||
)
|
||||
|
||||
err = MyError("custom msg", desc="New desc.", data={"x": 1})
|
||||
assert str(err) == "custom msg"
|
||||
assert err.api_error.msg == "custom msg" # detail also updates msg
|
||||
assert err.api_error.desc == "New desc."
|
||||
assert err.api_error.data == {"x": 1}
|
||||
assert err.api_error.code == 400 # other fields unchanged
|
||||
|
||||
def test_class_api_error_not_mutated_after_instance_override(self):
|
||||
"""Raising with desc/data does not mutate the class-level api_error."""
|
||||
|
||||
class MyError(ApiException):
|
||||
api_error = ApiError(
|
||||
code=400, msg="Error", desc="Default.", err_code="ERR-400"
|
||||
)
|
||||
|
||||
MyError(desc="Changed", data={"x": 1})
|
||||
assert MyError.api_error.desc == "Default."
|
||||
assert MyError.api_error.data is None
|
||||
|
||||
def test_subclass_uses_super_with_desc_and_data(self):
|
||||
"""Subclasses can delegate detail/desc/data to super().__init__()."""
|
||||
|
||||
class BuildValidationError(ApiException):
|
||||
api_error = ApiError(
|
||||
code=422,
|
||||
msg="Build Validation Error",
|
||||
desc="The build configuration is invalid.",
|
||||
err_code="BUILD-422",
|
||||
)
|
||||
|
||||
def __init__(self, *errors: str) -> None:
|
||||
super().__init__(
|
||||
f"{len(errors)} validation error(s)",
|
||||
desc=", ".join(errors),
|
||||
data={"errors": [{"message": e} for e in errors]},
|
||||
)
|
||||
|
||||
err = BuildValidationError("Field A is required", "Field B is invalid")
|
||||
assert str(err) == "2 validation error(s)"
|
||||
assert err.api_error.msg == "2 validation error(s)" # detail set msg
|
||||
assert err.api_error.desc == "Field A is required, Field B is invalid"
|
||||
assert err.api_error.data == {
|
||||
"errors": [
|
||||
{"message": "Field A is required"},
|
||||
{"message": "Field B is invalid"},
|
||||
]
|
||||
}
|
||||
assert err.api_error.code == 422 # other fields unchanged
|
||||
|
||||
def test_detail_desc_data_in_http_response(self):
|
||||
"""detail/desc/data overrides all appear correctly in the FastAPI HTTP response."""
|
||||
|
||||
class DynamicError(ApiException):
|
||||
api_error = ApiError(
|
||||
code=400, msg="Error", desc="Default.", err_code="ERR-400"
|
||||
)
|
||||
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(
|
||||
message,
|
||||
desc=f"Detail: {message}",
|
||||
data={"reason": message},
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app)
|
||||
|
||||
@app.get("/error")
|
||||
async def raise_error():
|
||||
raise DynamicError("something went wrong")
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/error")
|
||||
|
||||
assert response.status_code == 400
|
||||
body = response.json()
|
||||
assert body["message"] == "something went wrong"
|
||||
assert body["description"] == "Detail: something went wrong"
|
||||
assert body["data"] == {"reason": "something went wrong"}
|
||||
|
||||
|
||||
class TestApiExceptionGuard:
|
||||
"""Tests for the __init_subclass__ api_error guard."""
|
||||
|
||||
def test_missing_api_error_raises_type_error(self):
|
||||
"""Defining a subclass without api_error raises TypeError at class creation time."""
|
||||
with pytest.raises(
|
||||
TypeError, match="must define an 'api_error' class attribute"
|
||||
):
|
||||
|
||||
class BrokenError(ApiException):
|
||||
pass
|
||||
|
||||
def test_abstract_subclass_skips_guard(self):
|
||||
"""abstract=True allows intermediate base classes without api_error."""
|
||||
|
||||
class BaseGroupError(ApiException, abstract=True):
|
||||
pass
|
||||
|
||||
# Concrete child must still define it
|
||||
class ConcreteError(BaseGroupError):
|
||||
api_error = ApiError(
|
||||
code=400, msg="Error", desc="Desc.", err_code="ERR-400"
|
||||
)
|
||||
|
||||
err = ConcreteError()
|
||||
assert err.api_error.code == 400
|
||||
|
||||
def test_abstract_child_still_requires_api_error_on_concrete(self):
|
||||
"""Concrete subclass of an abstract class must define api_error."""
|
||||
|
||||
class Base(ApiException, abstract=True):
|
||||
pass
|
||||
|
||||
with pytest.raises(
|
||||
TypeError, match="must define an 'api_error' class attribute"
|
||||
):
|
||||
|
||||
class Concrete(Base):
|
||||
pass
|
||||
|
||||
def test_inherited_api_error_satisfies_guard(self):
|
||||
"""Subclass that inherits api_error from a parent does not need its own."""
|
||||
|
||||
class ConcreteError(NotFoundError):
|
||||
pass
|
||||
|
||||
err = ConcreteError()
|
||||
assert err.api_error.code == 404
|
||||
|
||||
|
||||
class TestBuiltInExceptions:
|
||||
@@ -89,7 +255,7 @@ class TestGenerateErrorResponses:
|
||||
assert responses[404]["description"] == "Not Found"
|
||||
|
||||
def test_generates_multiple_responses(self):
|
||||
"""Generates responses for multiple exceptions."""
|
||||
"""Generates responses for multiple exceptions with distinct status codes."""
|
||||
responses = generate_error_responses(
|
||||
UnauthorizedError,
|
||||
ForbiddenError,
|
||||
@@ -100,14 +266,81 @@ class TestGenerateErrorResponses:
|
||||
assert 403 in responses
|
||||
assert 404 in responses
|
||||
|
||||
def test_response_has_example(self):
|
||||
"""Generated response includes example."""
|
||||
def test_response_has_named_example(self):
|
||||
"""Generated response uses named examples keyed by err_code."""
|
||||
responses = generate_error_responses(NotFoundError)
|
||||
example = responses[404]["content"]["application/json"]["example"]
|
||||
examples = responses[404]["content"]["application/json"]["examples"]
|
||||
|
||||
assert example["status"] == "FAIL"
|
||||
assert example["error_code"] == "RES-404"
|
||||
assert example["message"] == "Not Found"
|
||||
assert "RES-404" in examples
|
||||
value = examples["RES-404"]["value"]
|
||||
assert value["status"] == "FAIL"
|
||||
assert value["error_code"] == "RES-404"
|
||||
assert value["message"] == "Not Found"
|
||||
assert value["data"] is None
|
||||
|
||||
def test_response_example_has_summary(self):
|
||||
"""Each named example carries a summary equal to api_error.msg."""
|
||||
responses = generate_error_responses(NotFoundError)
|
||||
example = responses[404]["content"]["application/json"]["examples"]["RES-404"]
|
||||
|
||||
assert example["summary"] == "Not Found"
|
||||
|
||||
def test_response_example_with_data(self):
|
||||
"""Generated response includes data when set on ApiError."""
|
||||
|
||||
class ErrorWithData(ApiException):
|
||||
api_error = ApiError(
|
||||
code=400,
|
||||
msg="Bad Request",
|
||||
desc="Invalid input.",
|
||||
err_code="BAD-400",
|
||||
data={"details": "some context"},
|
||||
)
|
||||
|
||||
responses = generate_error_responses(ErrorWithData)
|
||||
value = responses[400]["content"]["application/json"]["examples"]["BAD-400"][
|
||||
"value"
|
||||
]
|
||||
|
||||
assert value["data"] == {"details": "some context"}
|
||||
|
||||
def test_two_errors_same_code_both_present(self):
|
||||
"""Two exceptions with the same HTTP code produce two named examples."""
|
||||
|
||||
class BadRequestA(ApiException):
|
||||
api_error = ApiError(
|
||||
code=400, msg="Bad A", desc="Reason A.", err_code="ERR-A"
|
||||
)
|
||||
|
||||
class BadRequestB(ApiException):
|
||||
api_error = ApiError(
|
||||
code=400, msg="Bad B", desc="Reason B.", err_code="ERR-B"
|
||||
)
|
||||
|
||||
responses = generate_error_responses(BadRequestA, BadRequestB)
|
||||
|
||||
assert 400 in responses
|
||||
examples = responses[400]["content"]["application/json"]["examples"]
|
||||
assert "ERR-A" in examples
|
||||
assert "ERR-B" in examples
|
||||
assert examples["ERR-A"]["value"]["message"] == "Bad A"
|
||||
assert examples["ERR-B"]["value"]["message"] == "Bad B"
|
||||
|
||||
def test_two_errors_same_code_single_top_level_entry(self):
|
||||
"""Two exceptions with the same HTTP code produce exactly one top-level entry."""
|
||||
|
||||
class BadRequestA(ApiException):
|
||||
api_error = ApiError(
|
||||
code=400, msg="Bad A", desc="Reason A.", err_code="ERR-A"
|
||||
)
|
||||
|
||||
class BadRequestB(ApiException):
|
||||
api_error = ApiError(
|
||||
code=400, msg="Bad B", desc="Reason B.", err_code="ERR-B"
|
||||
)
|
||||
|
||||
responses = generate_error_responses(BadRequestA, BadRequestB)
|
||||
assert len([k for k in responses if k == 400]) == 1
|
||||
|
||||
|
||||
class TestInitExceptionsHandlers:
|
||||
@@ -137,6 +370,59 @@ class TestInitExceptionsHandlers:
|
||||
assert data["error_code"] == "RES-404"
|
||||
assert data["message"] == "Not Found"
|
||||
|
||||
def test_handles_api_exception_without_data(self):
|
||||
"""ApiException without data returns null data field."""
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app)
|
||||
|
||||
@app.get("/error")
|
||||
async def raise_error():
|
||||
raise NotFoundError()
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/error")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.json()["data"] is None
|
||||
|
||||
def test_handles_api_exception_with_data(self):
|
||||
"""ApiException with data returns the data payload."""
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app)
|
||||
|
||||
class CustomValidationError(ApiException):
|
||||
api_error = ApiError(
|
||||
code=422,
|
||||
msg="Validation Error",
|
||||
desc="1 validation error(s) detected",
|
||||
err_code="CUSTOM-422",
|
||||
data={
|
||||
"errors": [
|
||||
{
|
||||
"field": "email",
|
||||
"message": "invalid format",
|
||||
"type": "value_error",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@app.get("/error")
|
||||
async def raise_error():
|
||||
raise CustomValidationError()
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/error")
|
||||
|
||||
assert response.status_code == 422
|
||||
data = response.json()
|
||||
assert data["data"] == {
|
||||
"errors": [
|
||||
{"field": "email", "message": "invalid format", "type": "value_error"}
|
||||
]
|
||||
}
|
||||
assert data["error_code"] == "CUSTOM-422"
|
||||
|
||||
def test_handles_validation_error(self):
|
||||
"""Handles validation errors with structured response."""
|
||||
from pydantic import BaseModel
|
||||
@@ -178,13 +464,68 @@ class TestInitExceptionsHandlers:
|
||||
assert data["status"] == "FAIL"
|
||||
assert data["error_code"] == "SERVER-500"
|
||||
|
||||
def test_custom_openapi_schema(self):
|
||||
"""Customizes OpenAPI schema for 422 responses."""
|
||||
def test_handles_http_exception(self):
|
||||
"""Handles starlette HTTPException with consistent ErrorResponse envelope."""
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app)
|
||||
|
||||
@app.get("/protected")
|
||||
async def protected():
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/protected")
|
||||
|
||||
assert response.status_code == 403
|
||||
data = response.json()
|
||||
assert data["status"] == "FAIL"
|
||||
assert data["error_code"] == "HTTP-403"
|
||||
assert data["message"] == "Forbidden"
|
||||
|
||||
def test_handles_http_exception_404_from_route(self):
|
||||
"""HTTPException(404) raised inside a route uses the consistent ErrorResponse envelope."""
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app)
|
||||
|
||||
@app.get("/items/{item_id}")
|
||||
async def get_item(item_id: int):
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/items/99")
|
||||
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
assert data["status"] == "FAIL"
|
||||
assert data["error_code"] == "HTTP-404"
|
||||
assert data["message"] == "Item not found"
|
||||
|
||||
def test_handles_http_exception_forwards_headers(self):
|
||||
"""HTTPException with WWW-Authenticate header forwards it in the response."""
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app)
|
||||
|
||||
@app.get("/secure")
|
||||
async def secure():
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/secure")
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.headers.get("www-authenticate") == "Bearer"
|
||||
|
||||
def test_custom_openapi_schema(self):
|
||||
"""Customises OpenAPI schema for 422 responses using named examples."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app)
|
||||
|
||||
class Item(BaseModel):
|
||||
name: str
|
||||
|
||||
@@ -197,8 +538,128 @@ class TestInitExceptionsHandlers:
|
||||
post_op = openapi["paths"]["/items"]["post"]
|
||||
assert "422" in post_op["responses"]
|
||||
resp_422 = post_op["responses"]["422"]
|
||||
example = resp_422["content"]["application/json"]["example"]
|
||||
assert example["error_code"] == "VAL-422"
|
||||
examples = resp_422["content"]["application/json"]["examples"]
|
||||
assert "VAL-422" in examples
|
||||
assert examples["VAL-422"]["value"]["error_code"] == "VAL-422"
|
||||
|
||||
def test_custom_openapi_preserves_app_metadata(self):
|
||||
"""_patched_openapi preserves custom FastAPI app-level metadata."""
|
||||
app = FastAPI(
|
||||
title="My API",
|
||||
version="2.0.0",
|
||||
description="Custom description",
|
||||
)
|
||||
init_exceptions_handlers(app)
|
||||
|
||||
schema = app.openapi()
|
||||
assert schema["info"]["title"] == "My API"
|
||||
assert schema["info"]["version"] == "2.0.0"
|
||||
|
||||
def test_handles_response_validation_error(self):
|
||||
"""Handles ResponseValidationError with a structured 422 response."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
class CountResponse(BaseModel):
|
||||
count: int
|
||||
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app)
|
||||
|
||||
@app.get("/broken", response_model=CountResponse)
|
||||
async def broken():
|
||||
return {"count": "not-a-number"} # triggers ResponseValidationError
|
||||
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/broken")
|
||||
|
||||
assert response.status_code == 422
|
||||
data = response.json()
|
||||
assert data["status"] == "FAIL"
|
||||
assert data["error_code"] == "VAL-422"
|
||||
assert "errors" in data["data"]
|
||||
|
||||
def test_handles_validation_error_with_non_standard_loc(self):
|
||||
"""Validation error with empty loc tuple maps the field to 'root'."""
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app)
|
||||
|
||||
@app.get("/root-error")
|
||||
async def root_error():
|
||||
raise RequestValidationError(
|
||||
[
|
||||
{
|
||||
"type": "custom",
|
||||
"loc": (),
|
||||
"msg": "root level error",
|
||||
"input": None,
|
||||
"url": "",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/root-error")
|
||||
|
||||
assert response.status_code == 422
|
||||
data = response.json()
|
||||
assert data["data"]["errors"][0]["field"] == "root"
|
||||
|
||||
def test_openapi_schema_cached_after_first_call(self):
|
||||
"""app.openapi() returns the cached schema on subsequent calls."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app)
|
||||
|
||||
class Item(BaseModel):
|
||||
name: str
|
||||
|
||||
@app.post("/items")
|
||||
async def create_item(item: Item):
|
||||
return item
|
||||
|
||||
schema_first = app.openapi()
|
||||
schema_second = app.openapi()
|
||||
assert schema_first is schema_second
|
||||
|
||||
def test_openapi_skips_operations_without_422(self):
|
||||
"""_patched_openapi leaves operations that have no 422 response unchanged."""
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app)
|
||||
|
||||
@app.get("/ping")
|
||||
async def ping():
|
||||
return {"ok": True}
|
||||
|
||||
schema = app.openapi()
|
||||
get_op = schema["paths"]["/ping"]["get"]
|
||||
assert "422" not in get_op["responses"]
|
||||
assert "200" in get_op["responses"]
|
||||
|
||||
def test_openapi_skips_non_dict_path_item_values(self):
|
||||
"""_patched_openapi ignores non-dict values in path items (e.g. path-level parameters)."""
|
||||
from fastapi_toolsets.exceptions.handler import _patched_openapi
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
def fake_openapi() -> dict:
|
||||
return {
|
||||
"paths": {
|
||||
"/items": {
|
||||
"parameters": [
|
||||
{"name": "q", "in": "query"}
|
||||
], # list, not a dict
|
||||
"get": {"responses": {"200": {"description": "OK"}}},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
schema = _patched_openapi(app, fake_openapi)
|
||||
# The list value was skipped without error; the GET operation is intact
|
||||
assert schema["paths"]["/items"]["parameters"] == [{"name": "q", "in": "query"}]
|
||||
assert "422" not in schema["paths"]["/items"]["get"]["responses"]
|
||||
|
||||
|
||||
class TestExceptionIntegration:
|
||||
@@ -263,3 +724,43 @@ class TestExceptionIntegration:
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"id": 1}
|
||||
|
||||
|
||||
class TestInvalidOrderFieldError:
|
||||
"""Tests for InvalidOrderFieldError exception."""
|
||||
|
||||
def test_api_error_attributes(self):
|
||||
"""InvalidOrderFieldError has correct api_error metadata."""
|
||||
assert InvalidOrderFieldError.api_error.code == 422
|
||||
assert InvalidOrderFieldError.api_error.err_code == "SORT-422"
|
||||
assert InvalidOrderFieldError.api_error.msg == "Invalid Order Field"
|
||||
|
||||
def test_stores_field_and_valid_fields(self):
|
||||
"""InvalidOrderFieldError stores field and valid_fields on the instance."""
|
||||
error = InvalidOrderFieldError("unknown", ["name", "created_at"])
|
||||
assert error.field == "unknown"
|
||||
assert error.valid_fields == ["name", "created_at"]
|
||||
|
||||
def test_description_contains_field_and_valid_fields(self):
|
||||
"""api_error.desc mentions the bad field and valid options."""
|
||||
error = InvalidOrderFieldError("bad_field", ["name", "email"])
|
||||
assert "bad_field" in error.api_error.desc
|
||||
assert "name" in error.api_error.desc
|
||||
assert "email" in error.api_error.desc
|
||||
|
||||
def test_handled_as_422_by_exception_handler(self):
|
||||
"""init_exceptions_handlers turns InvalidOrderFieldError into a 422 response."""
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app)
|
||||
|
||||
@app.get("/items")
|
||||
async def list_items():
|
||||
raise InvalidOrderFieldError("bad", ["name"])
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/items")
|
||||
|
||||
assert response.status_code == 422
|
||||
data = response.json()
|
||||
assert data["error_code"] == "SORT-422"
|
||||
assert data["status"] == "FAIL"
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Tests for fastapi_toolsets.fixtures module."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -12,7 +14,9 @@ from fastapi_toolsets.fixtures import (
|
||||
load_fixtures_by_context,
|
||||
)
|
||||
|
||||
from .conftest import Role, User
|
||||
from fastapi_toolsets.fixtures.utils import _get_primary_key
|
||||
|
||||
from .conftest import IntRole, Permission, Role, RoleCrud, User, UserCrud
|
||||
|
||||
|
||||
class TestContext:
|
||||
@@ -57,20 +61,22 @@ class TestFixtureRegistry:
|
||||
def test_register_with_decorator(self):
|
||||
"""Register fixture with decorator."""
|
||||
registry = FixtureRegistry()
|
||||
role_id = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def roles():
|
||||
return [Role(id=1, name="admin")]
|
||||
return [Role(id=role_id, name="admin")]
|
||||
|
||||
assert "roles" in [f.name for f in registry.get_all()]
|
||||
|
||||
def test_register_with_custom_name(self):
|
||||
"""Register fixture with custom name."""
|
||||
registry = FixtureRegistry()
|
||||
role_id = uuid.uuid4()
|
||||
|
||||
@registry.register(name="custom_roles")
|
||||
def roles():
|
||||
return [Role(id=1, name="admin")]
|
||||
return [Role(id=role_id, name="admin")]
|
||||
|
||||
fixture = registry.get("custom_roles")
|
||||
assert fixture.name == "custom_roles"
|
||||
@@ -78,14 +84,23 @@ class TestFixtureRegistry:
|
||||
def test_register_with_dependencies(self):
|
||||
"""Register fixture with dependencies."""
|
||||
registry = FixtureRegistry()
|
||||
role_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def roles():
|
||||
return [Role(id=1, name="admin")]
|
||||
return [Role(id=role_id, name="admin")]
|
||||
|
||||
@registry.register(depends_on=["roles"])
|
||||
def users():
|
||||
return [User(id=1, username="admin", email="admin@test.com", role_id=1)]
|
||||
return [
|
||||
User(
|
||||
id=user_id,
|
||||
username="admin",
|
||||
email="admin@test.com",
|
||||
role_id=role_id,
|
||||
)
|
||||
]
|
||||
|
||||
fixture = registry.get("users")
|
||||
assert fixture.depends_on == ["roles"]
|
||||
@@ -93,10 +108,11 @@ class TestFixtureRegistry:
|
||||
def test_register_with_contexts(self):
|
||||
"""Register fixture with contexts."""
|
||||
registry = FixtureRegistry()
|
||||
role_id = uuid.uuid4()
|
||||
|
||||
@registry.register(contexts=[Context.TESTING])
|
||||
def test_data():
|
||||
return [Role(id=100, name="test")]
|
||||
return [Role(id=role_id, name="test")]
|
||||
|
||||
fixture = registry.get("test_data")
|
||||
assert Context.TESTING.value in fixture.contexts
|
||||
@@ -145,6 +161,178 @@ class TestFixtureRegistry:
|
||||
assert names == {"test_data"}
|
||||
|
||||
|
||||
class TestIncludeRegistry:
|
||||
"""Tests for FixtureRegistry.include_registry method."""
|
||||
|
||||
def test_include_empty_registry(self):
|
||||
"""Include an empty registry does nothing."""
|
||||
main_registry = FixtureRegistry()
|
||||
other_registry = FixtureRegistry()
|
||||
|
||||
@main_registry.register
|
||||
def roles():
|
||||
return []
|
||||
|
||||
main_registry.include_registry(other_registry)
|
||||
|
||||
assert len(main_registry.get_all()) == 1
|
||||
|
||||
def test_include_registry_adds_fixtures(self):
|
||||
"""Include registry adds all fixtures from the other registry."""
|
||||
main_registry = FixtureRegistry()
|
||||
other_registry = FixtureRegistry()
|
||||
|
||||
@main_registry.register
|
||||
def roles():
|
||||
return []
|
||||
|
||||
@other_registry.register
|
||||
def users():
|
||||
return []
|
||||
|
||||
@other_registry.register
|
||||
def posts():
|
||||
return []
|
||||
|
||||
main_registry.include_registry(other_registry)
|
||||
|
||||
names = {f.name for f in main_registry.get_all()}
|
||||
assert names == {"roles", "users", "posts"}
|
||||
|
||||
def test_include_registry_preserves_dependencies(self):
|
||||
"""Include registry preserves fixture dependencies."""
|
||||
main_registry = FixtureRegistry()
|
||||
other_registry = FixtureRegistry()
|
||||
|
||||
@main_registry.register
|
||||
def roles():
|
||||
return []
|
||||
|
||||
@other_registry.register(depends_on=["roles"])
|
||||
def users():
|
||||
return []
|
||||
|
||||
main_registry.include_registry(other_registry)
|
||||
|
||||
fixture = main_registry.get("users")
|
||||
assert fixture.depends_on == ["roles"]
|
||||
|
||||
def test_include_registry_preserves_contexts(self):
|
||||
"""Include registry preserves fixture contexts."""
|
||||
main_registry = FixtureRegistry()
|
||||
other_registry = FixtureRegistry()
|
||||
|
||||
@other_registry.register(contexts=[Context.TESTING, Context.DEVELOPMENT])
|
||||
def test_data():
|
||||
return []
|
||||
|
||||
main_registry.include_registry(other_registry)
|
||||
|
||||
fixture = main_registry.get("test_data")
|
||||
assert Context.TESTING.value in fixture.contexts
|
||||
assert Context.DEVELOPMENT.value in fixture.contexts
|
||||
|
||||
def test_include_registry_raises_on_duplicate(self):
|
||||
"""Include registry raises ValueError on duplicate fixture names."""
|
||||
main_registry = FixtureRegistry()
|
||||
other_registry = FixtureRegistry()
|
||||
|
||||
@main_registry.register(name="roles")
|
||||
def roles_main():
|
||||
return []
|
||||
|
||||
@other_registry.register(name="roles")
|
||||
def roles_other():
|
||||
return []
|
||||
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
main_registry.include_registry(other_registry)
|
||||
|
||||
def test_include_multiple_registries(self):
|
||||
"""Include multiple registries sequentially."""
|
||||
main_registry = FixtureRegistry()
|
||||
dev_registry = FixtureRegistry()
|
||||
test_registry = FixtureRegistry()
|
||||
|
||||
@main_registry.register
|
||||
def base():
|
||||
return []
|
||||
|
||||
@dev_registry.register
|
||||
def dev_data():
|
||||
return []
|
||||
|
||||
@test_registry.register
|
||||
def test_data():
|
||||
return []
|
||||
|
||||
main_registry.include_registry(dev_registry)
|
||||
main_registry.include_registry(test_registry)
|
||||
|
||||
names = {f.name for f in main_registry.get_all()}
|
||||
assert names == {"base", "dev_data", "test_data"}
|
||||
|
||||
|
||||
class TestDefaultContexts:
|
||||
"""Tests for FixtureRegistry default contexts."""
|
||||
|
||||
def test_default_contexts_applied_to_fixtures(self):
|
||||
"""Default contexts are applied when no contexts specified."""
|
||||
registry = FixtureRegistry(contexts=[Context.TESTING])
|
||||
|
||||
@registry.register
|
||||
def test_data():
|
||||
return []
|
||||
|
||||
fixture = registry.get("test_data")
|
||||
assert fixture.contexts == [Context.TESTING.value]
|
||||
|
||||
def test_explicit_contexts_override_default(self):
|
||||
"""Explicit contexts override default contexts."""
|
||||
registry = FixtureRegistry(contexts=[Context.TESTING])
|
||||
|
||||
@registry.register(contexts=[Context.PRODUCTION])
|
||||
def prod_data():
|
||||
return []
|
||||
|
||||
fixture = registry.get("prod_data")
|
||||
assert fixture.contexts == [Context.PRODUCTION.value]
|
||||
|
||||
def test_no_default_contexts_uses_base(self):
|
||||
"""Without default contexts, BASE is used."""
|
||||
registry = FixtureRegistry()
|
||||
|
||||
@registry.register
|
||||
def data():
|
||||
return []
|
||||
|
||||
fixture = registry.get("data")
|
||||
assert fixture.contexts == [Context.BASE.value]
|
||||
|
||||
def test_multiple_default_contexts(self):
|
||||
"""Multiple default contexts are applied."""
|
||||
registry = FixtureRegistry(contexts=[Context.DEVELOPMENT, Context.TESTING])
|
||||
|
||||
@registry.register
|
||||
def dev_test_data():
|
||||
return []
|
||||
|
||||
fixture = registry.get("dev_test_data")
|
||||
assert Context.DEVELOPMENT.value in fixture.contexts
|
||||
assert Context.TESTING.value in fixture.contexts
|
||||
|
||||
def test_default_contexts_with_string_values(self):
|
||||
"""Default contexts work with string values."""
|
||||
registry = FixtureRegistry(contexts=["custom_context"])
|
||||
|
||||
@registry.register
|
||||
def custom_data():
|
||||
return []
|
||||
|
||||
fixture = registry.get("custom_data")
|
||||
assert fixture.contexts == ["custom_context"]
|
||||
|
||||
|
||||
class TestDependencyResolution:
|
||||
"""Tests for fixture dependency resolution."""
|
||||
|
||||
@@ -244,12 +432,14 @@ class TestLoadFixtures:
|
||||
async def test_load_single_fixture(self, db_session: AsyncSession):
|
||||
"""Load a single fixture."""
|
||||
registry = FixtureRegistry()
|
||||
role_id_1 = uuid.uuid4()
|
||||
role_id_2 = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def roles():
|
||||
return [
|
||||
Role(id=1, name="admin"),
|
||||
Role(id=2, name="user"),
|
||||
Role(id=role_id_1, name="admin"),
|
||||
Role(id=role_id_2, name="user"),
|
||||
]
|
||||
|
||||
result = await load_fixtures(db_session, registry, "roles")
|
||||
@@ -257,8 +447,6 @@ class TestLoadFixtures:
|
||||
assert "roles" in result
|
||||
assert len(result["roles"]) == 2
|
||||
|
||||
from .conftest import RoleCrud
|
||||
|
||||
count = await RoleCrud.count(db_session)
|
||||
assert count == 2
|
||||
|
||||
@@ -266,22 +454,29 @@ class TestLoadFixtures:
|
||||
async def test_load_with_dependencies(self, db_session: AsyncSession):
|
||||
"""Load fixtures with dependencies."""
|
||||
registry = FixtureRegistry()
|
||||
role_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def roles():
|
||||
return [Role(id=1, name="admin")]
|
||||
return [Role(id=role_id, name="admin")]
|
||||
|
||||
@registry.register(depends_on=["roles"])
|
||||
def users():
|
||||
return [User(id=1, username="admin", email="admin@test.com", role_id=1)]
|
||||
return [
|
||||
User(
|
||||
id=user_id,
|
||||
username="admin",
|
||||
email="admin@test.com",
|
||||
role_id=role_id,
|
||||
)
|
||||
]
|
||||
|
||||
result = await load_fixtures(db_session, registry, "users")
|
||||
|
||||
assert "roles" in result
|
||||
assert "users" in result
|
||||
|
||||
from .conftest import RoleCrud, UserCrud
|
||||
|
||||
assert await RoleCrud.count(db_session) == 1
|
||||
assert await UserCrud.count(db_session) == 1
|
||||
|
||||
@@ -289,16 +484,15 @@ class TestLoadFixtures:
|
||||
async def test_load_with_merge_strategy(self, db_session: AsyncSession):
|
||||
"""Load fixtures with MERGE strategy updates existing."""
|
||||
registry = FixtureRegistry()
|
||||
role_id = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def roles():
|
||||
return [Role(id=1, name="admin")]
|
||||
return [Role(id=role_id, name="admin")]
|
||||
|
||||
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
|
||||
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
|
||||
|
||||
from .conftest import RoleCrud
|
||||
|
||||
count = await RoleCrud.count(db_session)
|
||||
assert count == 1
|
||||
|
||||
@@ -306,10 +500,11 @@ class TestLoadFixtures:
|
||||
async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession):
|
||||
"""Load fixtures with SKIP_EXISTING strategy."""
|
||||
registry = FixtureRegistry()
|
||||
role_id = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def roles():
|
||||
return [Role(id=1, name="original")]
|
||||
return [Role(id=role_id, name="original")]
|
||||
|
||||
await load_fixtures(
|
||||
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
|
||||
@@ -317,7 +512,7 @@ class TestLoadFixtures:
|
||||
|
||||
@registry.register(name="roles_updated")
|
||||
def roles_v2():
|
||||
return [Role(id=1, name="updated")]
|
||||
return [Role(id=role_id, name="updated")]
|
||||
|
||||
registry._fixtures["roles"] = registry._fixtures.pop("roles_updated")
|
||||
|
||||
@@ -325,9 +520,7 @@ class TestLoadFixtures:
|
||||
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
|
||||
)
|
||||
|
||||
from .conftest import RoleCrud
|
||||
|
||||
role = await RoleCrud.first(db_session, [Role.id == 1])
|
||||
role = await RoleCrud.first(db_session, [Role.id == role_id])
|
||||
assert role is not None
|
||||
assert role.name == "original"
|
||||
|
||||
@@ -335,12 +528,14 @@ class TestLoadFixtures:
|
||||
async def test_load_with_insert_strategy(self, db_session: AsyncSession):
|
||||
"""Load fixtures with INSERT strategy."""
|
||||
registry = FixtureRegistry()
|
||||
role_id_1 = uuid.uuid4()
|
||||
role_id_2 = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def roles():
|
||||
return [
|
||||
Role(id=1, name="admin"),
|
||||
Role(id=2, name="user"),
|
||||
Role(id=role_id_1, name="admin"),
|
||||
Role(id=role_id_2, name="user"),
|
||||
]
|
||||
|
||||
result = await load_fixtures(
|
||||
@@ -350,8 +545,6 @@ class TestLoadFixtures:
|
||||
assert "roles" in result
|
||||
assert len(result["roles"]) == 2
|
||||
|
||||
from .conftest import RoleCrud
|
||||
|
||||
count = await RoleCrud.count(db_session)
|
||||
assert count == 2
|
||||
|
||||
@@ -375,25 +568,65 @@ class TestLoadFixtures:
|
||||
):
|
||||
"""Load multiple independent fixtures."""
|
||||
registry = FixtureRegistry()
|
||||
role_id_1 = uuid.uuid4()
|
||||
role_id_2 = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def roles():
|
||||
return [Role(id=1, name="admin")]
|
||||
return [Role(id=role_id_1, name="admin")]
|
||||
|
||||
@registry.register
|
||||
def other_roles():
|
||||
return [Role(id=2, name="user")]
|
||||
return [Role(id=role_id_2, name="user")]
|
||||
|
||||
result = await load_fixtures(db_session, registry, "roles", "other_roles")
|
||||
|
||||
assert "roles" in result
|
||||
assert "other_roles" in result
|
||||
|
||||
from .conftest import RoleCrud
|
||||
|
||||
count = await RoleCrud.count(db_session)
|
||||
assert count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_skip_existing_skips_if_record_exists(self, db_session: AsyncSession):
|
||||
"""SKIP_EXISTING returns empty loaded list when the record already exists."""
|
||||
registry = FixtureRegistry()
|
||||
role_id = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def roles():
|
||||
return [Role(id=role_id, name="admin")]
|
||||
|
||||
# First load — inserts the record.
|
||||
result1 = await load_fixtures(
|
||||
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
|
||||
)
|
||||
assert len(result1["roles"]) == 1
|
||||
|
||||
# Remove from identity map so session.get() queries the DB in the second load.
|
||||
db_session.expunge_all()
|
||||
|
||||
# Second load — record exists in DB, nothing should be added.
|
||||
result2 = await load_fixtures(
|
||||
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
|
||||
)
|
||||
assert result2["roles"] == []
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_skip_existing_null_pk_inserts(self, db_session: AsyncSession):
|
||||
"""SKIP_EXISTING inserts when the instance has no PK set (auto-increment)."""
|
||||
registry = FixtureRegistry()
|
||||
|
||||
@registry.register
|
||||
def int_roles():
|
||||
# No id provided — PK is None before INSERT (autoincrement).
|
||||
return [IntRole(name="member")]
|
||||
|
||||
result = await load_fixtures(
|
||||
db_session, registry, "int_roles", strategy=LoadStrategy.SKIP_EXISTING
|
||||
)
|
||||
assert len(result["int_roles"]) == 1
|
||||
|
||||
|
||||
class TestLoadFixturesByContext:
|
||||
"""Tests for load_fixtures_by_context function."""
|
||||
@@ -402,23 +635,23 @@ class TestLoadFixturesByContext:
|
||||
async def test_load_by_single_context(self, db_session: AsyncSession):
|
||||
"""Load fixtures by single context."""
|
||||
registry = FixtureRegistry()
|
||||
base_role_id = uuid.uuid4()
|
||||
test_role_id = uuid.uuid4()
|
||||
|
||||
@registry.register(contexts=[Context.BASE])
|
||||
def base_roles():
|
||||
return [Role(id=1, name="base_role")]
|
||||
return [Role(id=base_role_id, name="base_role")]
|
||||
|
||||
@registry.register(contexts=[Context.TESTING])
|
||||
def test_roles():
|
||||
return [Role(id=100, name="test_role")]
|
||||
return [Role(id=test_role_id, name="test_role")]
|
||||
|
||||
await load_fixtures_by_context(db_session, registry, Context.BASE)
|
||||
|
||||
from .conftest import RoleCrud
|
||||
|
||||
count = await RoleCrud.count(db_session)
|
||||
assert count == 1
|
||||
|
||||
role = await RoleCrud.first(db_session, [Role.id == 1])
|
||||
role = await RoleCrud.first(db_session, [Role.id == base_role_id])
|
||||
assert role is not None
|
||||
assert role.name == "base_role"
|
||||
|
||||
@@ -426,21 +659,21 @@ class TestLoadFixturesByContext:
|
||||
async def test_load_by_multiple_contexts(self, db_session: AsyncSession):
|
||||
"""Load fixtures by multiple contexts."""
|
||||
registry = FixtureRegistry()
|
||||
base_role_id = uuid.uuid4()
|
||||
test_role_id = uuid.uuid4()
|
||||
|
||||
@registry.register(contexts=[Context.BASE])
|
||||
def base_roles():
|
||||
return [Role(id=1, name="base_role")]
|
||||
return [Role(id=base_role_id, name="base_role")]
|
||||
|
||||
@registry.register(contexts=[Context.TESTING])
|
||||
def test_roles():
|
||||
return [Role(id=100, name="test_role")]
|
||||
return [Role(id=test_role_id, name="test_role")]
|
||||
|
||||
await load_fixtures_by_context(
|
||||
db_session, registry, Context.BASE, Context.TESTING
|
||||
)
|
||||
|
||||
from .conftest import RoleCrud
|
||||
|
||||
count = await RoleCrud.count(db_session)
|
||||
assert count == 2
|
||||
|
||||
@@ -448,19 +681,26 @@ class TestLoadFixturesByContext:
|
||||
async def test_load_context_with_dependencies(self, db_session: AsyncSession):
|
||||
"""Load context fixtures with cross-context dependencies."""
|
||||
registry = FixtureRegistry()
|
||||
role_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
@registry.register(contexts=[Context.BASE])
|
||||
def roles():
|
||||
return [Role(id=1, name="admin")]
|
||||
return [Role(id=role_id, name="admin")]
|
||||
|
||||
@registry.register(depends_on=["roles"], contexts=[Context.TESTING])
|
||||
def test_users():
|
||||
return [User(id=1, username="tester", email="test@test.com", role_id=1)]
|
||||
return [
|
||||
User(
|
||||
id=user_id,
|
||||
username="tester",
|
||||
email="test@test.com",
|
||||
role_id=role_id,
|
||||
)
|
||||
]
|
||||
|
||||
await load_fixtures_by_context(db_session, registry, Context.TESTING)
|
||||
|
||||
from .conftest import RoleCrud, UserCrud
|
||||
|
||||
assert await RoleCrud.count(db_session) == 1
|
||||
assert await UserCrud.count(db_session) == 1
|
||||
|
||||
@@ -471,20 +711,41 @@ class TestGetObjByAttr:
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures for each test."""
|
||||
self.registry = FixtureRegistry()
|
||||
self.role_id_1 = uuid.uuid4()
|
||||
self.role_id_2 = uuid.uuid4()
|
||||
self.role_id_3 = uuid.uuid4()
|
||||
self.user_id_1 = uuid.uuid4()
|
||||
self.user_id_2 = uuid.uuid4()
|
||||
|
||||
role_id_1 = self.role_id_1
|
||||
role_id_2 = self.role_id_2
|
||||
role_id_3 = self.role_id_3
|
||||
user_id_1 = self.user_id_1
|
||||
user_id_2 = self.user_id_2
|
||||
|
||||
@self.registry.register
|
||||
def roles() -> list[Role]:
|
||||
return [
|
||||
Role(id=1, name="admin"),
|
||||
Role(id=2, name="user"),
|
||||
Role(id=3, name="moderator"),
|
||||
Role(id=role_id_1, name="admin"),
|
||||
Role(id=role_id_2, name="user"),
|
||||
Role(id=role_id_3, name="moderator"),
|
||||
]
|
||||
|
||||
@self.registry.register(depends_on=["roles"])
|
||||
def users() -> list[User]:
|
||||
return [
|
||||
User(id=1, username="alice", email="alice@example.com", role_id=1),
|
||||
User(id=2, username="bob", email="bob@example.com", role_id=1),
|
||||
User(
|
||||
id=user_id_1,
|
||||
username="alice",
|
||||
email="alice@example.com",
|
||||
role_id=role_id_1,
|
||||
),
|
||||
User(
|
||||
id=user_id_2,
|
||||
username="bob",
|
||||
email="bob@example.com",
|
||||
role_id=role_id_1,
|
||||
),
|
||||
]
|
||||
|
||||
self.roles = roles
|
||||
@@ -492,26 +753,45 @@ class TestGetObjByAttr:
|
||||
|
||||
def test_get_by_id(self):
|
||||
"""Get an object by its id attribute."""
|
||||
role = get_obj_by_attr(self.roles, "id", 1)
|
||||
role = get_obj_by_attr(self.roles, "id", self.role_id_1)
|
||||
assert role.name == "admin"
|
||||
|
||||
def test_get_user_by_username(self):
|
||||
"""Get a user by username."""
|
||||
user = get_obj_by_attr(self.users, "username", "bob")
|
||||
assert user.id == 2
|
||||
assert user.id == self.user_id_2
|
||||
assert user.email == "bob@example.com"
|
||||
|
||||
def test_returns_first_match(self):
|
||||
"""Returns the first matching object when multiple could match."""
|
||||
user = get_obj_by_attr(self.users, "role_id", 1)
|
||||
user = get_obj_by_attr(self.users, "role_id", self.role_id_1)
|
||||
assert user.username == "alice"
|
||||
|
||||
def test_no_match_raises_stop_iteration(self):
|
||||
"""Raises StopIteration when no object matches."""
|
||||
with pytest.raises(StopIteration):
|
||||
"""Raises StopIteration with contextual message when no object matches."""
|
||||
with pytest.raises(
|
||||
StopIteration,
|
||||
match="No object with name=nonexistent found in fixture 'roles'",
|
||||
):
|
||||
get_obj_by_attr(self.roles, "name", "nonexistent")
|
||||
|
||||
def test_no_match_on_wrong_value_type(self):
|
||||
"""Raises StopIteration when value type doesn't match."""
|
||||
with pytest.raises(StopIteration):
|
||||
get_obj_by_attr(self.roles, "id", "1")
|
||||
get_obj_by_attr(self.roles, "id", "not-a-uuid")
|
||||
|
||||
|
||||
class TestGetPrimaryKey:
|
||||
"""Unit tests for the _get_primary_key helper (composite PK paths)."""
|
||||
|
||||
def test_composite_pk_all_set(self):
|
||||
"""Returns a tuple when all composite PK values are set."""
|
||||
instance = Permission(subject="post", action="read")
|
||||
pk = _get_primary_key(instance)
|
||||
assert pk == ("post", "read")
|
||||
|
||||
def test_composite_pk_partial_none(self):
|
||||
"""Returns None when any composite PK value is None."""
|
||||
instance = Permission(subject="post") # action is None
|
||||
pk = _get_primary_key(instance)
|
||||
assert pk is None
|
||||
|
||||
210
tests/test_imports.py
Normal file
210
tests/test_imports.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Tests for optional dependency import guards."""
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from fastapi_toolsets._imports import require_extra
|
||||
|
||||
|
||||
class TestRequireExtra:
|
||||
"""Tests for the require_extra helper."""
|
||||
|
||||
def test_raises_import_error(self):
|
||||
"""require_extra raises ImportError."""
|
||||
with pytest.raises(ImportError):
|
||||
require_extra(package="some_pkg", extra="some_extra")
|
||||
|
||||
def test_error_message_contains_package_name(self):
|
||||
"""Error message mentions the missing package."""
|
||||
with pytest.raises(ImportError, match="'prometheus_client'"):
|
||||
require_extra(package="prometheus_client", extra="metrics")
|
||||
|
||||
def test_error_message_contains_install_instruction(self):
|
||||
"""Error message contains the pip install command."""
|
||||
with pytest.raises(
|
||||
ImportError, match=r"pip install fastapi-toolsets\[metrics\]"
|
||||
):
|
||||
require_extra(package="prometheus_client", extra="metrics")
|
||||
|
||||
|
||||
def _reload_without_package(module_path: str, blocked_packages: list[str]):
|
||||
"""Reload a module while blocking specific package imports.
|
||||
|
||||
Removes the target module and its parents from sys.modules so they
|
||||
get re-imported, and patches builtins.__import__ to raise ImportError
|
||||
for *blocked_packages*.
|
||||
"""
|
||||
# Remove cached modules so they get re-imported
|
||||
to_remove = [
|
||||
key
|
||||
for key in sys.modules
|
||||
if key == module_path or key.startswith(module_path + ".")
|
||||
]
|
||||
saved = {}
|
||||
for key in to_remove:
|
||||
saved[key] = sys.modules.pop(key)
|
||||
|
||||
# Also remove parent package to force re-execution of __init__.py
|
||||
parts = module_path.rsplit(".", 1)
|
||||
if len(parts) == 2:
|
||||
parent = parts[0]
|
||||
parent_keys = [
|
||||
key for key in sys.modules if key == parent or key.startswith(parent + ".")
|
||||
]
|
||||
for key in parent_keys:
|
||||
if key not in saved:
|
||||
saved[key] = sys.modules.pop(key)
|
||||
|
||||
original_import = (
|
||||
__builtins__.__import__ if hasattr(__builtins__, "__import__") else __import__
|
||||
)
|
||||
|
||||
def blocking_import(name, *args, **kwargs):
|
||||
for blocked in blocked_packages:
|
||||
if name == blocked or name.startswith(blocked + "."):
|
||||
raise ImportError(f"Mocked: No module named '{name}'")
|
||||
return original_import(name, *args, **kwargs)
|
||||
|
||||
return saved, blocking_import
|
||||
|
||||
|
||||
class TestMetricsImportGuard:
|
||||
"""Tests for metrics module import guard when prometheus_client is missing."""
|
||||
|
||||
def test_registry_imports_without_prometheus(self):
|
||||
"""Metric and MetricsRegistry are importable without prometheus_client."""
|
||||
saved, blocking_import = _reload_without_package(
|
||||
"fastapi_toolsets.metrics", ["prometheus_client"]
|
||||
)
|
||||
try:
|
||||
with patch("builtins.__import__", side_effect=blocking_import):
|
||||
mod = importlib.import_module("fastapi_toolsets.metrics")
|
||||
# Registry types should be available (they're stdlib-only)
|
||||
assert hasattr(mod, "Metric")
|
||||
assert hasattr(mod, "MetricsRegistry")
|
||||
finally:
|
||||
# Restore original modules
|
||||
for key in list(sys.modules):
|
||||
if key.startswith("fastapi_toolsets.metrics"):
|
||||
sys.modules.pop(key, None)
|
||||
sys.modules.update(saved)
|
||||
|
||||
def test_init_metrics_stub_raises_without_prometheus(self):
|
||||
"""init_metrics raises ImportError when prometheus_client is missing."""
|
||||
saved, blocking_import = _reload_without_package(
|
||||
"fastapi_toolsets.metrics", ["prometheus_client"]
|
||||
)
|
||||
try:
|
||||
with patch("builtins.__import__", side_effect=blocking_import):
|
||||
mod = importlib.import_module("fastapi_toolsets.metrics")
|
||||
with pytest.raises(ImportError, match="prometheus_client"):
|
||||
mod.init_metrics(None, None) # type: ignore[arg-type]
|
||||
finally:
|
||||
for key in list(sys.modules):
|
||||
if key.startswith("fastapi_toolsets.metrics"):
|
||||
sys.modules.pop(key, None)
|
||||
sys.modules.update(saved)
|
||||
|
||||
def test_init_metrics_works_with_prometheus(self):
|
||||
"""init_metrics is the real function when prometheus_client is available."""
|
||||
from fastapi_toolsets.metrics import init_metrics
|
||||
|
||||
# Should be the real function, not a stub
|
||||
assert init_metrics.__module__ == "fastapi_toolsets.metrics.handler"
|
||||
|
||||
|
||||
class TestPytestImportGuard:
|
||||
"""Tests for pytest module import guard when dependencies are missing."""
|
||||
|
||||
def test_import_raises_without_pytest_package(self):
|
||||
"""Importing fastapi_toolsets.pytest raises when pytest is missing."""
|
||||
saved, blocking_import = _reload_without_package(
|
||||
"fastapi_toolsets.pytest", ["pytest"]
|
||||
)
|
||||
try:
|
||||
with patch("builtins.__import__", side_effect=blocking_import):
|
||||
with pytest.raises(ImportError, match="pytest"):
|
||||
importlib.import_module("fastapi_toolsets.pytest")
|
||||
finally:
|
||||
for key in list(sys.modules):
|
||||
if key.startswith("fastapi_toolsets.pytest"):
|
||||
sys.modules.pop(key, None)
|
||||
sys.modules.update(saved)
|
||||
|
||||
def test_import_raises_without_httpx(self):
|
||||
"""Importing fastapi_toolsets.pytest raises when httpx is missing."""
|
||||
saved, blocking_import = _reload_without_package(
|
||||
"fastapi_toolsets.pytest", ["httpx"]
|
||||
)
|
||||
try:
|
||||
with patch("builtins.__import__", side_effect=blocking_import):
|
||||
with pytest.raises(ImportError, match="httpx"):
|
||||
importlib.import_module("fastapi_toolsets.pytest")
|
||||
finally:
|
||||
for key in list(sys.modules):
|
||||
if key.startswith("fastapi_toolsets.pytest"):
|
||||
sys.modules.pop(key, None)
|
||||
sys.modules.update(saved)
|
||||
|
||||
def test_all_exports_available_with_deps(self):
|
||||
"""All expected exports are available when deps are installed."""
|
||||
from fastapi_toolsets.pytest import (
|
||||
cleanup_tables,
|
||||
create_async_client,
|
||||
create_db_session,
|
||||
create_worker_database,
|
||||
register_fixtures,
|
||||
worker_database_url,
|
||||
)
|
||||
|
||||
assert callable(register_fixtures)
|
||||
assert callable(create_async_client)
|
||||
assert callable(create_db_session)
|
||||
assert callable(create_worker_database)
|
||||
assert callable(worker_database_url)
|
||||
assert callable(cleanup_tables)
|
||||
|
||||
|
||||
class TestCliImportGuard:
|
||||
"""Tests for CLI module import guard when typer is missing."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"expected_match",
|
||||
[
|
||||
"typer",
|
||||
r"pip install fastapi-toolsets\[cli\]",
|
||||
],
|
||||
)
|
||||
def test_import_raises_without_typer(self, expected_match):
|
||||
"""Importing cli.app raises when typer is missing, with an informative error message."""
|
||||
saved, blocking_import = _reload_without_package(
|
||||
"fastapi_toolsets.cli.app", ["typer"]
|
||||
)
|
||||
# Also remove cli.config since it imports typer too
|
||||
config_keys = [
|
||||
k for k in sys.modules if k.startswith("fastapi_toolsets.cli.config")
|
||||
]
|
||||
for key in config_keys:
|
||||
if key not in saved:
|
||||
saved[key] = sys.modules.pop(key)
|
||||
|
||||
try:
|
||||
with patch("builtins.__import__", side_effect=blocking_import):
|
||||
with pytest.raises(ImportError, match=expected_match):
|
||||
importlib.import_module("fastapi_toolsets.cli.app")
|
||||
finally:
|
||||
for key in list(sys.modules):
|
||||
if key.startswith("fastapi_toolsets.cli.app") or key.startswith(
|
||||
"fastapi_toolsets.cli.config"
|
||||
):
|
||||
sys.modules.pop(key, None)
|
||||
sys.modules.update(saved)
|
||||
|
||||
def test_async_command_imports_without_typer(self):
|
||||
"""async_command is importable without typer (stdlib only)."""
|
||||
from fastapi_toolsets.cli import async_command
|
||||
|
||||
assert callable(async_command)
|
||||
118
tests/test_logger.py
Normal file
118
tests/test_logger.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from fastapi_toolsets.logger import (
|
||||
DEFAULT_FORMAT,
|
||||
UVICORN_LOGGERS,
|
||||
configure_logging,
|
||||
get_logger,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_loggers():
|
||||
"""Reset the root and uvicorn loggers after each test."""
|
||||
yield
|
||||
root = logging.getLogger()
|
||||
root.handlers.clear()
|
||||
root.setLevel(logging.WARNING)
|
||||
for name in UVICORN_LOGGERS:
|
||||
uv = logging.getLogger(name)
|
||||
uv.handlers.clear()
|
||||
uv.setLevel(logging.NOTSET)
|
||||
|
||||
|
||||
class TestConfigureLogging:
|
||||
def test_sets_up_handler_and_format(self):
|
||||
logger = configure_logging()
|
||||
|
||||
assert len(logger.handlers) == 1
|
||||
handler = logger.handlers[0]
|
||||
assert isinstance(handler, logging.StreamHandler)
|
||||
assert handler.stream is sys.stdout
|
||||
assert handler.formatter is not None
|
||||
assert handler.formatter._fmt == DEFAULT_FORMAT
|
||||
|
||||
def test_default_level_is_info(self):
|
||||
logger = configure_logging()
|
||||
|
||||
assert logger.level == logging.INFO
|
||||
|
||||
def test_custom_level_string(self):
|
||||
logger = configure_logging(level="DEBUG")
|
||||
|
||||
assert logger.level == logging.DEBUG
|
||||
|
||||
def test_custom_level_int(self):
|
||||
logger = configure_logging(level=logging.WARNING)
|
||||
|
||||
assert logger.level == logging.WARNING
|
||||
|
||||
def test_custom_format(self):
|
||||
custom_fmt = "%(levelname)s: %(message)s"
|
||||
logger = configure_logging(fmt=custom_fmt)
|
||||
|
||||
handler = logger.handlers[0]
|
||||
assert handler.formatter is not None
|
||||
assert handler.formatter._fmt == custom_fmt
|
||||
|
||||
def test_named_logger(self):
|
||||
logger = configure_logging(logger_name="myapp")
|
||||
|
||||
assert logger.name == "myapp"
|
||||
assert len(logger.handlers) == 1
|
||||
|
||||
def test_default_configures_root_logger(self):
|
||||
logger = configure_logging()
|
||||
|
||||
assert logger is logging.getLogger()
|
||||
|
||||
def test_idempotent_no_duplicate_handlers(self):
|
||||
configure_logging()
|
||||
configure_logging()
|
||||
logger = configure_logging()
|
||||
|
||||
assert len(logger.handlers) == 1
|
||||
|
||||
def test_configures_uvicorn_loggers(self):
|
||||
configure_logging(level="DEBUG")
|
||||
|
||||
for name in UVICORN_LOGGERS:
|
||||
uv_logger = logging.getLogger(name)
|
||||
assert len(uv_logger.handlers) == 1
|
||||
assert uv_logger.level == logging.DEBUG
|
||||
handler = uv_logger.handlers[0]
|
||||
assert handler.formatter is not None
|
||||
assert handler.formatter._fmt == DEFAULT_FORMAT
|
||||
|
||||
def test_returns_configured_logger(self):
|
||||
logger = configure_logging(logger_name="test.return")
|
||||
|
||||
assert isinstance(logger, logging.Logger)
|
||||
assert logger.name == "test.return"
|
||||
|
||||
|
||||
class TestGetLogger:
|
||||
def test_returns_named_logger(self):
|
||||
logger = get_logger("myapp.services")
|
||||
|
||||
assert isinstance(logger, logging.Logger)
|
||||
assert logger.name == "myapp.services"
|
||||
|
||||
def test_returns_root_logger_when_none(self):
|
||||
logger = get_logger(None)
|
||||
|
||||
assert logger is logging.getLogger()
|
||||
|
||||
def test_defaults_to_caller_module_name(self):
|
||||
logger = get_logger()
|
||||
|
||||
assert logger.name == __name__
|
||||
|
||||
def test_same_name_returns_same_logger(self):
|
||||
a = get_logger("myapp")
|
||||
b = get_logger("myapp")
|
||||
|
||||
assert a is b
|
||||
549
tests/test_metrics.py
Normal file
549
tests/test_metrics.py
Normal file
@@ -0,0 +1,549 @@
|
||||
"""Tests for fastapi_toolsets.metrics module."""
|
||||
|
||||
import tempfile
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from prometheus_client import REGISTRY, CollectorRegistry, Counter, Gauge
|
||||
|
||||
from fastapi_toolsets.metrics import Metric, MetricsRegistry, init_metrics
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_prometheus_registry():
|
||||
"""Unregister test collectors from the global registry after each test."""
|
||||
yield
|
||||
collectors = list(REGISTRY._names_to_collectors.values())
|
||||
for collector in collectors:
|
||||
try:
|
||||
REGISTRY.unregister(collector)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class TestMetric:
|
||||
"""Tests for Metric dataclass."""
|
||||
|
||||
def test_default_collect_is_false(self):
|
||||
"""Default collect is False (provider mode)."""
|
||||
definition = Metric(name="test", func=lambda: None)
|
||||
assert definition.collect is False
|
||||
|
||||
def test_collect_true(self):
|
||||
"""Collect can be set to True (collector mode)."""
|
||||
definition = Metric(name="test", func=lambda: None, collect=True)
|
||||
assert definition.collect is True
|
||||
|
||||
|
||||
class TestMetricsRegistry:
|
||||
"""Tests for MetricsRegistry class."""
|
||||
|
||||
def test_register_with_decorator(self):
|
||||
"""Register metric with bare decorator."""
|
||||
registry = MetricsRegistry()
|
||||
|
||||
@registry.register
|
||||
def my_counter():
|
||||
return Counter("test_counter", "A test counter")
|
||||
|
||||
names = [m.name for m in registry.get_all()]
|
||||
assert "my_counter" in names
|
||||
|
||||
def test_register_with_custom_name(self):
|
||||
"""Register metric with custom name."""
|
||||
registry = MetricsRegistry()
|
||||
|
||||
@registry.register(name="custom_name")
|
||||
def my_counter():
|
||||
return Counter("test_counter_2", "A test counter")
|
||||
|
||||
definition = registry.get_all()[0]
|
||||
assert definition.name == "custom_name"
|
||||
|
||||
def test_register_as_collector(self):
|
||||
"""Register metric with collect=True."""
|
||||
registry = MetricsRegistry()
|
||||
|
||||
@registry.register(collect=True)
|
||||
def collect_something():
|
||||
pass
|
||||
|
||||
definition = registry.get_all()[0]
|
||||
assert definition.collect is True
|
||||
|
||||
def test_register_preserves_function(self):
|
||||
"""Decorator returns the original function unchanged."""
|
||||
registry = MetricsRegistry()
|
||||
|
||||
def my_func():
|
||||
return "original"
|
||||
|
||||
result = registry.register(my_func)
|
||||
assert result is my_func
|
||||
assert result() == "original"
|
||||
|
||||
def test_register_parameterized_preserves_function(self):
|
||||
"""Parameterized decorator returns the original function unchanged."""
|
||||
registry = MetricsRegistry()
|
||||
|
||||
def my_func():
|
||||
return "original"
|
||||
|
||||
result = registry.register(name="custom")(my_func)
|
||||
assert result is my_func
|
||||
assert result() == "original"
|
||||
|
||||
def test_get_all(self):
|
||||
"""Get all registered metrics."""
|
||||
registry = MetricsRegistry()
|
||||
|
||||
@registry.register
|
||||
def metric_a():
|
||||
pass
|
||||
|
||||
@registry.register
|
||||
def metric_b():
|
||||
pass
|
||||
|
||||
names = {m.name for m in registry.get_all()}
|
||||
assert names == {"metric_a", "metric_b"}
|
||||
|
||||
def test_get_providers(self):
|
||||
"""Get only provider metrics (collect=False)."""
|
||||
registry = MetricsRegistry()
|
||||
|
||||
@registry.register
|
||||
def provider():
|
||||
pass
|
||||
|
||||
@registry.register(collect=True)
|
||||
def collector():
|
||||
pass
|
||||
|
||||
providers = registry.get_providers()
|
||||
assert len(providers) == 1
|
||||
assert providers[0].name == "provider"
|
||||
|
||||
def test_get_collectors(self):
|
||||
"""Get only collector metrics (collect=True)."""
|
||||
registry = MetricsRegistry()
|
||||
|
||||
@registry.register
|
||||
def provider():
|
||||
pass
|
||||
|
||||
@registry.register(collect=True)
|
||||
def collector():
|
||||
pass
|
||||
|
||||
collectors = registry.get_collectors()
|
||||
assert len(collectors) == 1
|
||||
assert collectors[0].name == "collector"
|
||||
|
||||
def test_register_overwrites_same_name(self):
|
||||
"""Registering with the same name overwrites the previous entry."""
|
||||
registry = MetricsRegistry()
|
||||
|
||||
@registry.register(name="metric")
|
||||
def first():
|
||||
pass
|
||||
|
||||
@registry.register(name="metric")
|
||||
def second():
|
||||
pass
|
||||
|
||||
assert len(registry.get_all()) == 1
|
||||
assert registry.get_all()[0].func is second
|
||||
|
||||
|
||||
class TestGet:
|
||||
"""Tests for MetricsRegistry.get method."""
|
||||
|
||||
def test_get_returns_instance_after_init(self):
|
||||
"""get() returns the metric instance stored by init_metrics."""
|
||||
app = FastAPI()
|
||||
registry = MetricsRegistry()
|
||||
|
||||
@registry.register
|
||||
def my_gauge():
|
||||
return Gauge("get_test_gauge", "A test gauge")
|
||||
|
||||
init_metrics(app, registry)
|
||||
|
||||
instance = registry.get("my_gauge")
|
||||
assert isinstance(instance, Gauge)
|
||||
|
||||
def test_get_raises_for_registered_but_not_initialized(self):
|
||||
"""get() raises KeyError with an informative message when init_metrics was not called."""
|
||||
registry = MetricsRegistry()
|
||||
|
||||
@registry.register
|
||||
def my_counter():
|
||||
return Counter("get_uninit_counter", "A counter")
|
||||
|
||||
with pytest.raises(KeyError, match="not been initialized yet"):
|
||||
registry.get("my_counter")
|
||||
|
||||
def test_get_raises_for_unknown_name(self):
|
||||
"""get() raises KeyError when the metric name is not registered at all."""
|
||||
registry = MetricsRegistry()
|
||||
|
||||
with pytest.raises(KeyError, match="Unknown metric"):
|
||||
registry.get("nonexistent")
|
||||
|
||||
|
||||
class TestIncludeRegistry:
|
||||
"""Tests for MetricsRegistry.include_registry method."""
|
||||
|
||||
def test_include_empty_registry(self):
|
||||
"""Include an empty registry does nothing."""
|
||||
main = MetricsRegistry()
|
||||
other = MetricsRegistry()
|
||||
|
||||
@main.register
|
||||
def metric_a():
|
||||
pass
|
||||
|
||||
main.include_registry(other)
|
||||
assert len(main.get_all()) == 1
|
||||
|
||||
def test_include_registry_adds_metrics(self):
|
||||
"""Include registry adds all metrics from the other registry."""
|
||||
main = MetricsRegistry()
|
||||
other = MetricsRegistry()
|
||||
|
||||
@main.register
|
||||
def metric_a():
|
||||
pass
|
||||
|
||||
@other.register
|
||||
def metric_b():
|
||||
pass
|
||||
|
||||
@other.register
|
||||
def metric_c():
|
||||
pass
|
||||
|
||||
main.include_registry(other)
|
||||
names = {m.name for m in main.get_all()}
|
||||
assert names == {"metric_a", "metric_b", "metric_c"}
|
||||
|
||||
def test_include_registry_preserves_collect_flag(self):
|
||||
"""Include registry preserves the collect flag."""
|
||||
main = MetricsRegistry()
|
||||
other = MetricsRegistry()
|
||||
|
||||
@other.register(collect=True)
|
||||
def collector():
|
||||
pass
|
||||
|
||||
main.include_registry(other)
|
||||
assert main.get_all()[0].collect is True
|
||||
|
||||
def test_include_registry_raises_on_duplicate(self):
|
||||
"""Include registry raises ValueError on duplicate metric names."""
|
||||
main = MetricsRegistry()
|
||||
other = MetricsRegistry()
|
||||
|
||||
@main.register(name="metric")
|
||||
def metric_main():
|
||||
pass
|
||||
|
||||
@other.register(name="metric")
|
||||
def metric_other():
|
||||
pass
|
||||
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
main.include_registry(other)
|
||||
|
||||
def test_include_multiple_registries(self):
|
||||
"""Include multiple registries sequentially."""
|
||||
main = MetricsRegistry()
|
||||
sub1 = MetricsRegistry()
|
||||
sub2 = MetricsRegistry()
|
||||
|
||||
@main.register
|
||||
def base():
|
||||
pass
|
||||
|
||||
@sub1.register
|
||||
def sub1_metric():
|
||||
pass
|
||||
|
||||
@sub2.register
|
||||
def sub2_metric():
|
||||
pass
|
||||
|
||||
main.include_registry(sub1)
|
||||
main.include_registry(sub2)
|
||||
|
||||
names = {m.name for m in main.get_all()}
|
||||
assert names == {"base", "sub1_metric", "sub2_metric"}
|
||||
|
||||
|
||||
class TestInitMetrics:
|
||||
"""Tests for init_metrics function."""
|
||||
|
||||
@pytest.fixture
|
||||
def metrics_client(self):
|
||||
"""Create a FastAPI app with MetricsRegistry and return a TestClient."""
|
||||
app = FastAPI()
|
||||
registry = MetricsRegistry()
|
||||
init_metrics(app, registry)
|
||||
client = TestClient(app)
|
||||
yield client
|
||||
client.close()
|
||||
|
||||
def test_returns_app(self):
|
||||
"""Returns the FastAPI app."""
|
||||
app = FastAPI()
|
||||
registry = MetricsRegistry()
|
||||
result = init_metrics(app, registry)
|
||||
assert result is app
|
||||
|
||||
def test_metrics_endpoint_responds(self, metrics_client):
|
||||
"""The /metrics endpoint returns 200."""
|
||||
response = metrics_client.get("/metrics")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_metrics_endpoint_content_type(self, metrics_client):
|
||||
"""The /metrics endpoint returns prometheus content type."""
|
||||
response = metrics_client.get("/metrics")
|
||||
assert "text/plain" in response.headers["content-type"]
|
||||
|
||||
def test_custom_path(self):
|
||||
"""Custom path is used for the metrics endpoint."""
|
||||
app = FastAPI()
|
||||
registry = MetricsRegistry()
|
||||
init_metrics(app, registry, path="/custom-metrics")
|
||||
|
||||
client = TestClient(app)
|
||||
assert client.get("/custom-metrics").status_code == 200
|
||||
assert client.get("/metrics").status_code == 404
|
||||
|
||||
def test_providers_called_at_init(self):
|
||||
"""Provider functions are called once at init time."""
|
||||
app = FastAPI()
|
||||
registry = MetricsRegistry()
|
||||
mock = MagicMock()
|
||||
|
||||
@registry.register
|
||||
def my_provider():
|
||||
mock()
|
||||
|
||||
init_metrics(app, registry)
|
||||
|
||||
mock.assert_called_once()
|
||||
|
||||
def test_collectors_called_on_scrape(self):
|
||||
"""Collector functions are called on each scrape."""
|
||||
app = FastAPI()
|
||||
registry = MetricsRegistry()
|
||||
mock = MagicMock()
|
||||
|
||||
@registry.register(collect=True)
|
||||
def my_collector():
|
||||
mock()
|
||||
|
||||
init_metrics(app, registry)
|
||||
|
||||
client = TestClient(app)
|
||||
client.get("/metrics")
|
||||
client.get("/metrics")
|
||||
|
||||
assert mock.call_count == 2
|
||||
|
||||
def test_collectors_not_called_at_init(self):
|
||||
"""Collector functions are not called at init time."""
|
||||
app = FastAPI()
|
||||
registry = MetricsRegistry()
|
||||
mock = MagicMock()
|
||||
|
||||
@registry.register(collect=True)
|
||||
def my_collector():
|
||||
mock()
|
||||
|
||||
init_metrics(app, registry)
|
||||
|
||||
mock.assert_not_called()
|
||||
|
||||
def test_async_collectors_called_on_scrape(self):
|
||||
"""Async collector functions are awaited on each scrape."""
|
||||
app = FastAPI()
|
||||
registry = MetricsRegistry()
|
||||
mock = AsyncMock()
|
||||
|
||||
@registry.register(collect=True)
|
||||
async def my_async_collector():
|
||||
await mock()
|
||||
|
||||
init_metrics(app, registry)
|
||||
|
||||
client = TestClient(app)
|
||||
client.get("/metrics")
|
||||
client.get("/metrics")
|
||||
|
||||
assert mock.call_count == 2
|
||||
|
||||
def test_mixed_sync_and_async_collectors(self):
|
||||
"""Both sync and async collectors are called on scrape."""
|
||||
app = FastAPI()
|
||||
registry = MetricsRegistry()
|
||||
sync_mock = MagicMock()
|
||||
async_mock = AsyncMock()
|
||||
|
||||
@registry.register(collect=True)
|
||||
def sync_collector():
|
||||
sync_mock()
|
||||
|
||||
@registry.register(collect=True)
|
||||
async def async_collector():
|
||||
await async_mock()
|
||||
|
||||
init_metrics(app, registry)
|
||||
|
||||
client = TestClient(app)
|
||||
client.get("/metrics")
|
||||
|
||||
sync_mock.assert_called_once()
|
||||
async_mock.assert_called_once()
|
||||
|
||||
def test_registered_metrics_appear_in_output(self):
|
||||
"""Metrics created by providers appear in /metrics output."""
|
||||
app = FastAPI()
|
||||
registry = MetricsRegistry()
|
||||
|
||||
@registry.register
|
||||
def my_gauge():
|
||||
g = Gauge("test_gauge_value", "A test gauge")
|
||||
g.set(42)
|
||||
return g
|
||||
|
||||
init_metrics(app, registry)
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/metrics")
|
||||
|
||||
assert b"test_gauge_value" in response.content
|
||||
assert b"42.0" in response.content
|
||||
|
||||
def test_endpoint_not_in_openapi_schema(self):
|
||||
"""The /metrics endpoint is not included in the OpenAPI schema."""
|
||||
app = FastAPI()
|
||||
registry = MetricsRegistry()
|
||||
init_metrics(app, registry)
|
||||
|
||||
schema = app.openapi()
|
||||
assert "/metrics" not in schema.get("paths", {})
|
||||
|
||||
|
||||
class TestMultiProcessMode:
|
||||
"""Tests for multi-process Prometheus mode."""
|
||||
|
||||
def test_multiprocess_with_env_var(self, monkeypatch):
|
||||
"""Multi-process mode works when PROMETHEUS_MULTIPROC_DIR is set."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
monkeypatch.setenv("PROMETHEUS_MULTIPROC_DIR", tmpdir)
|
||||
# Use a separate registry to avoid conflicts with default
|
||||
prom_registry = CollectorRegistry()
|
||||
app = FastAPI()
|
||||
registry = MetricsRegistry()
|
||||
|
||||
@registry.register
|
||||
def mp_counter():
|
||||
return Counter(
|
||||
"mp_test_counter",
|
||||
"A multiprocess counter",
|
||||
registry=prom_registry,
|
||||
)
|
||||
|
||||
init_metrics(app, registry)
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/metrics")
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_single_process_without_env_var(self, monkeypatch):
|
||||
"""Single-process mode when PROMETHEUS_MULTIPROC_DIR is not set."""
|
||||
monkeypatch.delenv("PROMETHEUS_MULTIPROC_DIR", raising=False)
|
||||
|
||||
app = FastAPI()
|
||||
registry = MetricsRegistry()
|
||||
|
||||
@registry.register
|
||||
def sp_gauge():
|
||||
g = Gauge("sp_test_gauge", "A single-process gauge")
|
||||
g.set(99)
|
||||
return g
|
||||
|
||||
init_metrics(app, registry)
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/metrics")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert b"sp_test_gauge" in response.content
|
||||
|
||||
|
||||
class TestMetricsIntegration:
|
||||
"""Integration tests for the metrics module."""
|
||||
|
||||
def test_full_workflow(self):
|
||||
"""Full workflow: registry, providers, collectors, endpoint."""
|
||||
app = FastAPI()
|
||||
registry = MetricsRegistry()
|
||||
call_count = {"value": 0}
|
||||
|
||||
@registry.register
|
||||
def request_counter():
|
||||
return Counter(
|
||||
"integration_requests_total",
|
||||
"Total requests",
|
||||
["method"],
|
||||
)
|
||||
|
||||
@registry.register(collect=True)
|
||||
def collect_uptime():
|
||||
call_count["value"] += 1
|
||||
|
||||
init_metrics(app, registry)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/metrics")
|
||||
assert response.status_code == 200
|
||||
assert b"integration_requests_total" in response.content
|
||||
assert call_count["value"] == 1
|
||||
|
||||
response = client.get("/metrics")
|
||||
assert call_count["value"] == 2
|
||||
|
||||
def test_multiple_registries_merged(self):
|
||||
"""Multiple registries can be merged and used together."""
|
||||
app = FastAPI()
|
||||
main = MetricsRegistry()
|
||||
sub = MetricsRegistry()
|
||||
|
||||
@main.register
|
||||
def main_gauge():
|
||||
g = Gauge("main_gauge_val", "Main gauge")
|
||||
g.set(1)
|
||||
return g
|
||||
|
||||
@sub.register
|
||||
def sub_gauge():
|
||||
g = Gauge("sub_gauge_val", "Sub gauge")
|
||||
g.set(2)
|
||||
return g
|
||||
|
||||
main.include_registry(sub)
|
||||
init_metrics(app, main)
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/metrics")
|
||||
|
||||
assert b"main_gauge_val" in response.content
|
||||
assert b"sub_gauge_val" in response.content
|
||||
1281
tests/test_models.py
Normal file
1281
tests/test_models.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,44 +1,72 @@
|
||||
"""Tests for fastapi_toolsets.pytest module."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Depends, FastAPI
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from fastapi_toolsets.fixtures import Context, FixtureRegistry
|
||||
from fastapi_toolsets.pytest import (
|
||||
create_async_client,
|
||||
create_db_session,
|
||||
create_worker_database,
|
||||
register_fixtures,
|
||||
worker_database_url,
|
||||
)
|
||||
from fastapi_toolsets.pytest.utils import _get_xdist_worker
|
||||
|
||||
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
|
||||
|
||||
test_registry = FixtureRegistry()
|
||||
|
||||
# Fixed UUIDs for test fixtures to allow consistent assertions
|
||||
ROLE_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000001000")
|
||||
ROLE_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000001001")
|
||||
USER_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000002000")
|
||||
USER_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000002001")
|
||||
USER_EXTRA_ID = uuid.UUID("00000000-0000-0000-0000-000000002002")
|
||||
|
||||
|
||||
@test_registry.register(contexts=[Context.BASE])
|
||||
def roles() -> list[Role]:
|
||||
return [
|
||||
Role(id=1000, name="plugin_admin"),
|
||||
Role(id=1001, name="plugin_user"),
|
||||
Role(id=ROLE_ADMIN_ID, name="plugin_admin"),
|
||||
Role(id=ROLE_USER_ID, name="plugin_user"),
|
||||
]
|
||||
|
||||
|
||||
@test_registry.register(depends_on=["roles"], contexts=[Context.BASE])
|
||||
def users() -> list[User]:
|
||||
return [
|
||||
User(id=1000, username="plugin_admin", email="padmin@test.com", role_id=1000),
|
||||
User(id=1001, username="plugin_user", email="puser@test.com", role_id=1001),
|
||||
User(
|
||||
id=USER_ADMIN_ID,
|
||||
username="plugin_admin",
|
||||
email="padmin@test.com",
|
||||
role_id=ROLE_ADMIN_ID,
|
||||
),
|
||||
User(
|
||||
id=USER_USER_ID,
|
||||
username="plugin_user",
|
||||
email="puser@test.com",
|
||||
role_id=ROLE_USER_ID,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@test_registry.register(depends_on=["users"], contexts=[Context.TESTING])
|
||||
def extra_users() -> list[User]:
|
||||
return [
|
||||
User(id=1002, username="plugin_extra", email="pextra@test.com", role_id=1001),
|
||||
User(
|
||||
id=USER_EXTRA_ID,
|
||||
username="plugin_extra",
|
||||
email="pextra@test.com",
|
||||
role_id=ROLE_USER_ID,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@@ -73,7 +101,7 @@ class TestGeneratedFixtures:
|
||||
assert fixture_roles[1].name == "plugin_user"
|
||||
|
||||
# Verify data is in database
|
||||
count = await RoleCrud.count(db_session, [Role.id >= 1000])
|
||||
count = await RoleCrud.count(db_session)
|
||||
assert count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -86,11 +114,11 @@ class TestGeneratedFixtures:
|
||||
assert len(fixture_users) == 2
|
||||
|
||||
# Roles should also be in database
|
||||
roles_count = await RoleCrud.count(db_session, [Role.id >= 1000])
|
||||
roles_count = await RoleCrud.count(db_session)
|
||||
assert roles_count == 2
|
||||
|
||||
# Users should be in database
|
||||
users_count = await UserCrud.count(db_session, [User.id >= 1000])
|
||||
users_count = await UserCrud.count(db_session)
|
||||
assert users_count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -100,7 +128,7 @@ class TestGeneratedFixtures:
|
||||
"""Fixture returns actual model instances."""
|
||||
user = fixture_users[0]
|
||||
assert isinstance(user, User)
|
||||
assert user.id == 1000
|
||||
assert user.id == USER_ADMIN_ID
|
||||
assert user.username == "plugin_admin"
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -111,7 +139,7 @@ class TestGeneratedFixtures:
|
||||
# Load user with role relationship
|
||||
user = await UserCrud.get(
|
||||
db_session,
|
||||
[User.id == 1000],
|
||||
[User.id == USER_ADMIN_ID],
|
||||
load_options=[selectinload(User.role)],
|
||||
)
|
||||
|
||||
@@ -127,8 +155,8 @@ class TestGeneratedFixtures:
|
||||
assert len(fixture_extra_users) == 1
|
||||
|
||||
# All fixtures should be loaded
|
||||
roles_count = await RoleCrud.count(db_session, [Role.id >= 1000])
|
||||
users_count = await UserCrud.count(db_session, [User.id >= 1000])
|
||||
roles_count = await RoleCrud.count(db_session)
|
||||
users_count = await UserCrud.count(db_session)
|
||||
|
||||
assert roles_count == 2
|
||||
assert users_count == 3 # 2 from users + 1 from extra_users
|
||||
@@ -141,8 +169,7 @@ class TestGeneratedFixtures:
|
||||
# Get all users loaded by fixture
|
||||
users = await UserCrud.get_multi(
|
||||
db_session,
|
||||
filters=[User.id >= 1000],
|
||||
order_by=User.id,
|
||||
order_by=User.username,
|
||||
)
|
||||
|
||||
assert len(users) == 2
|
||||
@@ -161,8 +188,8 @@ class TestGeneratedFixtures:
|
||||
assert len(fixture_users) == 2
|
||||
|
||||
# Both should be in database
|
||||
roles = await RoleCrud.get_multi(db_session, filters=[Role.id >= 1000])
|
||||
users = await UserCrud.get_multi(db_session, filters=[User.id >= 1000])
|
||||
roles = await RoleCrud.get_multi(db_session)
|
||||
users = await UserCrud.get_multi(db_session)
|
||||
|
||||
assert len(roles) == 2
|
||||
assert len(users) == 2
|
||||
@@ -208,6 +235,30 @@ class TestCreateAsyncClient:
|
||||
|
||||
assert client_ref.is_closed
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_dependency_overrides_applied_and_cleaned(self):
|
||||
"""Dependency overrides are applied during the context and removed after."""
|
||||
app = FastAPI()
|
||||
|
||||
async def original_dep() -> str:
|
||||
return "original"
|
||||
|
||||
async def override_dep() -> str:
|
||||
return "overridden"
|
||||
|
||||
@app.get("/dep")
|
||||
async def dep_endpoint(value: str = Depends(original_dep)):
|
||||
return {"value": value}
|
||||
|
||||
async with create_async_client(
|
||||
app, dependency_overrides={original_dep: override_dep}
|
||||
) as client:
|
||||
response = await client.get("/dep")
|
||||
assert response.json() == {"value": "overridden"}
|
||||
|
||||
# Overrides should be cleaned up
|
||||
assert original_dep not in app.dependency_overrides
|
||||
|
||||
|
||||
class TestCreateDbSession:
|
||||
"""Tests for create_db_session helper."""
|
||||
@@ -215,14 +266,15 @@ class TestCreateDbSession:
|
||||
@pytest.mark.anyio
|
||||
async def test_creates_working_session(self):
|
||||
"""Session can perform database operations."""
|
||||
role_id = uuid.uuid4()
|
||||
async with create_db_session(DATABASE_URL, Base) as session:
|
||||
assert isinstance(session, AsyncSession)
|
||||
|
||||
role = Role(id=9001, name="test_helper_role")
|
||||
role = Role(id=role_id, name="test_helper_role")
|
||||
session.add(role)
|
||||
await session.commit()
|
||||
|
||||
result = await session.execute(select(Role).where(Role.id == 9001))
|
||||
result = await session.execute(select(Role).where(Role.id == role_id))
|
||||
fetched = result.scalar_one()
|
||||
assert fetched.name == "test_helper_role"
|
||||
|
||||
@@ -237,8 +289,9 @@ class TestCreateDbSession:
|
||||
@pytest.mark.anyio
|
||||
async def test_tables_dropped_after_session(self):
|
||||
"""Tables are dropped after session closes when drop_tables=True."""
|
||||
role_id = uuid.uuid4()
|
||||
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||
role = Role(id=9002, name="will_be_dropped")
|
||||
role = Role(id=role_id, name="will_be_dropped")
|
||||
session.add(role)
|
||||
await session.commit()
|
||||
|
||||
@@ -250,14 +303,15 @@ class TestCreateDbSession:
|
||||
@pytest.mark.anyio
|
||||
async def test_tables_preserved_when_drop_disabled(self):
|
||||
"""Tables are preserved when drop_tables=False."""
|
||||
role_id = uuid.uuid4()
|
||||
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
|
||||
role = Role(id=9003, name="preserved_role")
|
||||
role = Role(id=role_id, name="preserved_role")
|
||||
session.add(role)
|
||||
await session.commit()
|
||||
|
||||
# Create another session without dropping
|
||||
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
|
||||
result = await session.execute(select(Role).where(Role.id == 9003))
|
||||
result = await session.execute(select(Role).where(Role.id == role_id))
|
||||
fetched = result.scalar_one_or_none()
|
||||
assert fetched is not None
|
||||
assert fetched.name == "preserved_role"
|
||||
@@ -265,3 +319,163 @@ class TestCreateDbSession:
|
||||
# Cleanup: drop tables manually
|
||||
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _:
|
||||
pass
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cleanup_truncates_tables(self):
|
||||
"""Tables are truncated after session closes when cleanup=True."""
|
||||
role_id = uuid.uuid4()
|
||||
async with create_db_session(
|
||||
DATABASE_URL, Base, cleanup=True, drop_tables=False
|
||||
) as session:
|
||||
role = Role(id=role_id, name="will_be_cleaned")
|
||||
session.add(role)
|
||||
await session.commit()
|
||||
|
||||
# Data should have been truncated, but tables still exist
|
||||
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||
result = await session.execute(select(Role))
|
||||
assert result.all() == []
|
||||
|
||||
|
||||
class TestGetXdistWorker:
|
||||
"""Tests for _get_xdist_worker helper."""
|
||||
|
||||
def test_returns_default_test_db_without_env_var(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""Returns default_test_db when PYTEST_XDIST_WORKER is not set."""
|
||||
monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
|
||||
assert _get_xdist_worker("my_default") == "my_default"
|
||||
|
||||
def test_returns_worker_name(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Returns the worker name from the environment variable."""
|
||||
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw0")
|
||||
assert _get_xdist_worker("ignored") == "gw0"
|
||||
|
||||
|
||||
class TestWorkerDatabaseUrl:
|
||||
"""Tests for worker_database_url helper."""
|
||||
|
||||
def test_appends_default_test_db_without_xdist(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""default_test_db is appended when not running under xdist."""
|
||||
monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
|
||||
url = "postgresql+asyncpg://user:pass@localhost:5432/mydb"
|
||||
result = worker_database_url(url, default_test_db="fallback")
|
||||
assert make_url(result).database == "mydb_fallback"
|
||||
|
||||
def test_appends_worker_id_to_database_name(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Worker name is appended to the database name."""
|
||||
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw0")
|
||||
url = "postgresql+asyncpg://user:pass@localhost:5432/db"
|
||||
result = worker_database_url(url, default_test_db="unused")
|
||||
assert make_url(result).database == "db_gw0"
|
||||
|
||||
def test_preserves_url_components(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Host, port, username, password, and driver are preserved."""
|
||||
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw2")
|
||||
url = "postgresql+asyncpg://myuser:secret@dbhost:6543/testdb"
|
||||
result = make_url(worker_database_url(url, default_test_db="unused"))
|
||||
|
||||
assert result.drivername == "postgresql+asyncpg"
|
||||
assert result.username == "myuser"
|
||||
assert result.password == "secret"
|
||||
assert result.host == "dbhost"
|
||||
assert result.port == 6543
|
||||
assert result.database == "testdb_gw2"
|
||||
|
||||
|
||||
class TestCreateWorkerDatabase:
|
||||
"""Tests for create_worker_database context manager."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_creates_default_db_without_xdist(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""Without xdist, creates a database suffixed with default_test_db."""
|
||||
monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
|
||||
default_test_db = "no_xdist_default"
|
||||
expected_db = make_url(
|
||||
worker_database_url(DATABASE_URL, default_test_db=default_test_db)
|
||||
).database
|
||||
|
||||
async with create_worker_database(
|
||||
DATABASE_URL, default_test_db=default_test_db
|
||||
) as url:
|
||||
assert make_url(url).database == expected_db
|
||||
|
||||
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||
async with engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||
{"name": expected_db},
|
||||
)
|
||||
assert result.scalar() == 1
|
||||
await engine.dispose()
|
||||
|
||||
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||
async with engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||
{"name": expected_db},
|
||||
)
|
||||
assert result.scalar() is None
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_creates_and_drops_worker_database(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""Worker database exists inside the context and is dropped after."""
|
||||
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw_test_create")
|
||||
expected_db = make_url(
|
||||
worker_database_url(DATABASE_URL, default_test_db="unused")
|
||||
).database
|
||||
|
||||
async with create_worker_database(DATABASE_URL) as url:
|
||||
assert make_url(url).database == expected_db
|
||||
|
||||
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||
async with engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||
{"name": expected_db},
|
||||
)
|
||||
assert result.scalar() == 1
|
||||
await engine.dispose()
|
||||
|
||||
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||
async with engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||
{"name": expected_db},
|
||||
)
|
||||
assert result.scalar() is None
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cleans_up_stale_database(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""A pre-existing worker database is dropped and recreated."""
|
||||
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw_test_stale")
|
||||
expected_db = make_url(
|
||||
worker_database_url(DATABASE_URL, default_test_db="unused")
|
||||
).database
|
||||
|
||||
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||
async with engine.connect() as conn:
|
||||
await conn.execute(text(f"DROP DATABASE IF EXISTS {expected_db}"))
|
||||
await conn.execute(text(f"CREATE DATABASE {expected_db}"))
|
||||
await engine.dispose()
|
||||
|
||||
async with create_worker_database(DATABASE_URL) as url:
|
||||
assert make_url(url).database == expected_db
|
||||
|
||||
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||
async with engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||
{"name": expected_db},
|
||||
)
|
||||
assert result.scalar() is None
|
||||
await engine.dispose()
|
||||
|
||||
@@ -5,9 +5,13 @@ from pydantic import ValidationError
|
||||
|
||||
from fastapi_toolsets.schemas import (
|
||||
ApiError,
|
||||
CursorPagination,
|
||||
CursorPaginatedResponse,
|
||||
ErrorResponse,
|
||||
OffsetPagination,
|
||||
OffsetPaginatedResponse,
|
||||
PaginatedResponse,
|
||||
Pagination,
|
||||
PaginationType,
|
||||
Response,
|
||||
ResponseStatus,
|
||||
)
|
||||
@@ -46,6 +50,31 @@ class TestApiError:
|
||||
assert error.desc == "The resource was not found."
|
||||
assert error.err_code == "RES-404"
|
||||
|
||||
def test_data_defaults_to_none(self):
|
||||
"""ApiError data field defaults to None."""
|
||||
error = ApiError(
|
||||
code=404,
|
||||
msg="Not Found",
|
||||
desc="The resource was not found.",
|
||||
err_code="RES-404",
|
||||
)
|
||||
assert error.data is None
|
||||
|
||||
def test_create_with_data(self):
|
||||
"""ApiError can be created with a data payload."""
|
||||
error = ApiError(
|
||||
code=422,
|
||||
msg="Validation Error",
|
||||
desc="2 validation error(s) detected",
|
||||
err_code="VAL-422",
|
||||
data={
|
||||
"errors": [{"field": "name", "message": "required", "type": "missing"}]
|
||||
},
|
||||
)
|
||||
assert error.data == {
|
||||
"errors": [{"field": "name", "message": "required", "type": "missing"}]
|
||||
}
|
||||
|
||||
def test_requires_all_fields(self):
|
||||
"""ApiError requires all fields."""
|
||||
with pytest.raises(ValidationError):
|
||||
@@ -129,12 +158,12 @@ class TestErrorResponse:
|
||||
assert data["description"] == "Details"
|
||||
|
||||
|
||||
class TestPagination:
|
||||
"""Tests for Pagination schema."""
|
||||
class TestOffsetPagination:
|
||||
"""Tests for OffsetPagination schema (canonical name for offset-based pagination)."""
|
||||
|
||||
def test_create_pagination(self):
|
||||
"""Create Pagination with all fields."""
|
||||
pagination = Pagination(
|
||||
"""Create OffsetPagination with all fields."""
|
||||
pagination = OffsetPagination(
|
||||
total_count=100,
|
||||
items_per_page=10,
|
||||
page=1,
|
||||
@@ -148,7 +177,7 @@ class TestPagination:
|
||||
|
||||
def test_last_page_has_more_false(self):
|
||||
"""Last page has has_more=False."""
|
||||
pagination = Pagination(
|
||||
pagination = OffsetPagination(
|
||||
total_count=25,
|
||||
items_per_page=10,
|
||||
page=3,
|
||||
@@ -158,8 +187,8 @@ class TestPagination:
|
||||
assert pagination.has_more is False
|
||||
|
||||
def test_serialization(self):
|
||||
"""Pagination serializes correctly."""
|
||||
pagination = Pagination(
|
||||
"""OffsetPagination serializes correctly."""
|
||||
pagination = OffsetPagination(
|
||||
total_count=50,
|
||||
items_per_page=20,
|
||||
page=2,
|
||||
@@ -172,13 +201,152 @@ class TestPagination:
|
||||
assert data["page"] == 2
|
||||
assert data["has_more"] is True
|
||||
|
||||
def test_total_count_can_be_none(self):
|
||||
"""total_count accepts None (include_total=False mode)."""
|
||||
pagination = OffsetPagination(
|
||||
total_count=None,
|
||||
items_per_page=20,
|
||||
page=1,
|
||||
has_more=True,
|
||||
)
|
||||
assert pagination.total_count is None
|
||||
|
||||
def test_serialization_with_none_total_count(self):
|
||||
"""OffsetPagination serializes total_count=None correctly."""
|
||||
pagination = OffsetPagination(
|
||||
total_count=None,
|
||||
items_per_page=20,
|
||||
page=1,
|
||||
has_more=False,
|
||||
)
|
||||
data = pagination.model_dump()
|
||||
assert data["total_count"] is None
|
||||
|
||||
def test_pages_computed(self):
|
||||
"""pages is ceil(total_count / items_per_page)."""
|
||||
pagination = OffsetPagination(
|
||||
total_count=42,
|
||||
items_per_page=10,
|
||||
page=1,
|
||||
has_more=True,
|
||||
)
|
||||
assert pagination.pages == 5
|
||||
|
||||
def test_pages_exact_division(self):
|
||||
"""pages is exact when total_count is evenly divisible."""
|
||||
pagination = OffsetPagination(
|
||||
total_count=40,
|
||||
items_per_page=10,
|
||||
page=1,
|
||||
has_more=False,
|
||||
)
|
||||
assert pagination.pages == 4
|
||||
|
||||
def test_pages_zero_total(self):
|
||||
"""pages is 0 when total_count is 0."""
|
||||
pagination = OffsetPagination(
|
||||
total_count=0,
|
||||
items_per_page=10,
|
||||
page=1,
|
||||
has_more=False,
|
||||
)
|
||||
assert pagination.pages == 0
|
||||
|
||||
def test_pages_zero_items_per_page(self):
|
||||
"""pages is 0 when items_per_page is 0."""
|
||||
pagination = OffsetPagination(
|
||||
total_count=100,
|
||||
items_per_page=0,
|
||||
page=1,
|
||||
has_more=False,
|
||||
)
|
||||
assert pagination.pages == 0
|
||||
|
||||
def test_pages_none_when_total_count_none(self):
|
||||
"""pages is None when total_count is None (include_total=False)."""
|
||||
pagination = OffsetPagination(
|
||||
total_count=None,
|
||||
items_per_page=20,
|
||||
page=1,
|
||||
has_more=True,
|
||||
)
|
||||
assert pagination.pages is None
|
||||
|
||||
def test_pages_in_serialization(self):
|
||||
"""pages appears in model_dump output."""
|
||||
pagination = OffsetPagination(
|
||||
total_count=25,
|
||||
items_per_page=10,
|
||||
page=1,
|
||||
has_more=True,
|
||||
)
|
||||
data = pagination.model_dump()
|
||||
assert data["pages"] == 3
|
||||
|
||||
|
||||
class TestCursorPagination:
|
||||
"""Tests for CursorPagination schema."""
|
||||
|
||||
def test_create_with_next_cursor(self):
|
||||
"""CursorPagination with a next cursor indicates more pages."""
|
||||
pagination = CursorPagination(
|
||||
next_cursor="eyJ2YWx1ZSI6ICIxMjMifQ==",
|
||||
items_per_page=20,
|
||||
has_more=True,
|
||||
)
|
||||
assert pagination.next_cursor == "eyJ2YWx1ZSI6ICIxMjMifQ=="
|
||||
assert pagination.prev_cursor is None
|
||||
assert pagination.items_per_page == 20
|
||||
assert pagination.has_more is True
|
||||
|
||||
def test_create_last_page(self):
|
||||
"""CursorPagination for the last page has next_cursor=None and has_more=False."""
|
||||
pagination = CursorPagination(
|
||||
next_cursor=None,
|
||||
items_per_page=20,
|
||||
has_more=False,
|
||||
)
|
||||
assert pagination.next_cursor is None
|
||||
assert pagination.has_more is False
|
||||
|
||||
def test_prev_cursor_defaults_to_none(self):
|
||||
"""prev_cursor defaults to None."""
|
||||
pagination = CursorPagination(
|
||||
next_cursor=None, items_per_page=10, has_more=False
|
||||
)
|
||||
assert pagination.prev_cursor is None
|
||||
|
||||
def test_prev_cursor_can_be_set(self):
|
||||
"""prev_cursor can be explicitly set."""
|
||||
pagination = CursorPagination(
|
||||
next_cursor="next123",
|
||||
prev_cursor="prev456",
|
||||
items_per_page=10,
|
||||
has_more=True,
|
||||
)
|
||||
assert pagination.prev_cursor == "prev456"
|
||||
|
||||
def test_serialization(self):
|
||||
"""CursorPagination serializes correctly."""
|
||||
pagination = CursorPagination(
|
||||
next_cursor="abc123",
|
||||
prev_cursor="xyz789",
|
||||
items_per_page=20,
|
||||
has_more=True,
|
||||
)
|
||||
data = pagination.model_dump()
|
||||
assert data["next_cursor"] == "abc123"
|
||||
assert data["prev_cursor"] == "xyz789"
|
||||
assert data["items_per_page"] == 20
|
||||
assert data["has_more"] is True
|
||||
|
||||
|
||||
class TestPaginatedResponse:
|
||||
"""Tests for PaginatedResponse schema."""
|
||||
|
||||
def test_create_paginated_response(self):
|
||||
"""Create PaginatedResponse with data and pagination."""
|
||||
pagination = Pagination(
|
||||
pagination = OffsetPagination(
|
||||
total_count=30,
|
||||
items_per_page=10,
|
||||
page=1,
|
||||
@@ -189,13 +357,14 @@ class TestPaginatedResponse:
|
||||
pagination=pagination,
|
||||
)
|
||||
|
||||
assert isinstance(response.pagination, OffsetPagination)
|
||||
assert len(response.data) == 2
|
||||
assert response.pagination.total_count == 30
|
||||
assert response.status == ResponseStatus.SUCCESS
|
||||
|
||||
def test_with_custom_message(self):
|
||||
"""PaginatedResponse with custom message."""
|
||||
pagination = Pagination(
|
||||
pagination = OffsetPagination(
|
||||
total_count=5,
|
||||
items_per_page=10,
|
||||
page=1,
|
||||
@@ -211,28 +380,48 @@ class TestPaginatedResponse:
|
||||
|
||||
def test_empty_data(self):
|
||||
"""PaginatedResponse with empty data."""
|
||||
pagination = Pagination(
|
||||
pagination = OffsetPagination(
|
||||
total_count=0,
|
||||
items_per_page=10,
|
||||
page=1,
|
||||
has_more=False,
|
||||
)
|
||||
response = PaginatedResponse[dict](
|
||||
response = PaginatedResponse(
|
||||
data=[],
|
||||
pagination=pagination,
|
||||
)
|
||||
|
||||
assert isinstance(response.pagination, OffsetPagination)
|
||||
assert response.data == []
|
||||
assert response.pagination.total_count == 0
|
||||
|
||||
def test_class_getitem_with_concrete_type_returns_discriminated_union(self):
|
||||
"""PaginatedResponse[T] with a concrete type returns a discriminated Annotated union."""
|
||||
import typing
|
||||
|
||||
alias = PaginatedResponse[dict]
|
||||
args = typing.get_args(alias)
|
||||
# args[0] is the Union, args[1] is the FieldInfo discriminator
|
||||
union_args = typing.get_args(args[0])
|
||||
assert CursorPaginatedResponse[dict] in union_args
|
||||
assert OffsetPaginatedResponse[dict] in union_args
|
||||
|
||||
def test_class_getitem_is_cached(self):
|
||||
"""Repeated subscripting with the same type returns the identical cached object."""
|
||||
assert PaginatedResponse[dict] is PaginatedResponse[dict]
|
||||
|
||||
def test_class_getitem_with_typevar_returns_generic(self):
|
||||
"""PaginatedResponse[TypeVar] falls through to Pydantic generic parametrisation."""
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
alias = PaginatedResponse[T]
|
||||
# Should be a generic alias, not an Annotated union
|
||||
assert not hasattr(alias, "__metadata__")
|
||||
|
||||
def test_generic_type_hint(self):
|
||||
"""PaginatedResponse supports generic type hints."""
|
||||
|
||||
class UserOut:
|
||||
id: int
|
||||
name: str
|
||||
|
||||
pagination = Pagination(
|
||||
pagination = OffsetPagination(
|
||||
total_count=1,
|
||||
items_per_page=10,
|
||||
page=1,
|
||||
@@ -247,7 +436,7 @@ class TestPaginatedResponse:
|
||||
|
||||
def test_serialization(self):
|
||||
"""PaginatedResponse serializes correctly."""
|
||||
pagination = Pagination(
|
||||
pagination = OffsetPagination(
|
||||
total_count=100,
|
||||
items_per_page=10,
|
||||
page=5,
|
||||
@@ -265,6 +454,211 @@ class TestPaginatedResponse:
|
||||
assert data["data"] == ["item1", "item2"]
|
||||
assert data["pagination"]["page"] == 5
|
||||
|
||||
def test_pagination_field_accepts_offset_pagination(self):
|
||||
"""PaginatedResponse.pagination accepts OffsetPagination."""
|
||||
response = PaginatedResponse(
|
||||
data=[1, 2],
|
||||
pagination=OffsetPagination(
|
||||
total_count=2, items_per_page=10, page=1, has_more=False
|
||||
),
|
||||
)
|
||||
assert isinstance(response.pagination, OffsetPagination)
|
||||
|
||||
def test_pagination_field_accepts_cursor_pagination(self):
|
||||
"""PaginatedResponse.pagination accepts CursorPagination."""
|
||||
response = PaginatedResponse(
|
||||
data=[1, 2],
|
||||
pagination=CursorPagination(
|
||||
next_cursor=None, items_per_page=10, has_more=False
|
||||
),
|
||||
)
|
||||
assert isinstance(response.pagination, CursorPagination)
|
||||
|
||||
|
||||
class TestPaginationType:
|
||||
"""Tests for PaginationType enum."""
|
||||
|
||||
def test_offset_value(self):
|
||||
"""OFFSET has string value 'offset'."""
|
||||
assert PaginationType.OFFSET == "offset"
|
||||
assert PaginationType.OFFSET.value == "offset"
|
||||
|
||||
def test_cursor_value(self):
|
||||
"""CURSOR has string value 'cursor'."""
|
||||
assert PaginationType.CURSOR == "cursor"
|
||||
assert PaginationType.CURSOR.value == "cursor"
|
||||
|
||||
def test_is_string_enum(self):
|
||||
"""PaginationType is a string enum."""
|
||||
assert isinstance(PaginationType.OFFSET, str)
|
||||
assert isinstance(PaginationType.CURSOR, str)
|
||||
|
||||
def test_members(self):
|
||||
"""PaginationType has exactly two members."""
|
||||
assert set(PaginationType) == {PaginationType.OFFSET, PaginationType.CURSOR}
|
||||
|
||||
|
||||
class TestOffsetPaginatedResponse:
|
||||
"""Tests for OffsetPaginatedResponse schema."""
|
||||
|
||||
def test_pagination_type_is_offset(self):
|
||||
"""pagination_type is always PaginationType.OFFSET."""
|
||||
response = OffsetPaginatedResponse(
|
||||
data=[],
|
||||
pagination=OffsetPagination(
|
||||
total_count=0, items_per_page=10, page=1, has_more=False
|
||||
),
|
||||
)
|
||||
assert response.pagination_type is PaginationType.OFFSET
|
||||
|
||||
def test_pagination_type_serializes_to_string(self):
|
||||
"""pagination_type serializes to 'offset' in JSON mode."""
|
||||
response = OffsetPaginatedResponse(
|
||||
data=[],
|
||||
pagination=OffsetPagination(
|
||||
total_count=0, items_per_page=10, page=1, has_more=False
|
||||
),
|
||||
)
|
||||
assert response.model_dump(mode="json")["pagination_type"] == "offset"
|
||||
|
||||
def test_pagination_field_is_typed(self):
|
||||
"""pagination field is OffsetPagination, not the union."""
|
||||
response = OffsetPaginatedResponse(
|
||||
data=[{"id": 1}],
|
||||
pagination=OffsetPagination(
|
||||
total_count=10, items_per_page=5, page=2, has_more=True
|
||||
),
|
||||
)
|
||||
assert isinstance(response.pagination, OffsetPagination)
|
||||
assert response.pagination.total_count == 10
|
||||
assert response.pagination.page == 2
|
||||
|
||||
def test_is_subclass_of_paginated_response(self):
|
||||
"""OffsetPaginatedResponse IS a PaginatedResponse."""
|
||||
response = OffsetPaginatedResponse(
|
||||
data=[],
|
||||
pagination=OffsetPagination(
|
||||
total_count=0, items_per_page=10, page=1, has_more=False
|
||||
),
|
||||
)
|
||||
assert isinstance(response, PaginatedResponse)
|
||||
|
||||
def test_pagination_type_default_cannot_be_overridden_to_cursor(self):
|
||||
"""pagination_type rejects values other than OFFSET."""
|
||||
with pytest.raises(ValidationError):
|
||||
OffsetPaginatedResponse(
|
||||
data=[],
|
||||
pagination=OffsetPagination(
|
||||
total_count=0, items_per_page=10, page=1, has_more=False
|
||||
),
|
||||
pagination_type=PaginationType.CURSOR, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
def test_filter_attributes_defaults_to_none(self):
|
||||
"""filter_attributes defaults to None."""
|
||||
response = OffsetPaginatedResponse(
|
||||
data=[],
|
||||
pagination=OffsetPagination(
|
||||
total_count=0, items_per_page=10, page=1, has_more=False
|
||||
),
|
||||
)
|
||||
assert response.filter_attributes is None
|
||||
|
||||
def test_full_serialization(self):
|
||||
"""Full JSON serialization includes all expected fields."""
|
||||
response = OffsetPaginatedResponse(
|
||||
data=[{"id": 1}],
|
||||
pagination=OffsetPagination(
|
||||
total_count=1, items_per_page=10, page=1, has_more=False
|
||||
),
|
||||
filter_attributes={"status": ["active"]},
|
||||
)
|
||||
data = response.model_dump(mode="json")
|
||||
|
||||
assert data["pagination_type"] == "offset"
|
||||
assert data["status"] == "SUCCESS"
|
||||
assert data["data"] == [{"id": 1}]
|
||||
assert data["pagination"]["total_count"] == 1
|
||||
assert data["filter_attributes"] == {"status": ["active"]}
|
||||
|
||||
|
||||
class TestCursorPaginatedResponse:
|
||||
"""Tests for CursorPaginatedResponse schema."""
|
||||
|
||||
def test_pagination_type_is_cursor(self):
|
||||
"""pagination_type is always PaginationType.CURSOR."""
|
||||
response = CursorPaginatedResponse(
|
||||
data=[],
|
||||
pagination=CursorPagination(
|
||||
next_cursor=None, items_per_page=10, has_more=False
|
||||
),
|
||||
)
|
||||
assert response.pagination_type is PaginationType.CURSOR
|
||||
|
||||
def test_pagination_type_serializes_to_string(self):
|
||||
"""pagination_type serializes to 'cursor' in JSON mode."""
|
||||
response = CursorPaginatedResponse(
|
||||
data=[],
|
||||
pagination=CursorPagination(
|
||||
next_cursor=None, items_per_page=10, has_more=False
|
||||
),
|
||||
)
|
||||
assert response.model_dump(mode="json")["pagination_type"] == "cursor"
|
||||
|
||||
def test_pagination_field_is_typed(self):
|
||||
"""pagination field is CursorPagination, not the union."""
|
||||
response = CursorPaginatedResponse(
|
||||
data=[{"id": 1}],
|
||||
pagination=CursorPagination(
|
||||
next_cursor="abc123",
|
||||
prev_cursor=None,
|
||||
items_per_page=20,
|
||||
has_more=True,
|
||||
),
|
||||
)
|
||||
assert isinstance(response.pagination, CursorPagination)
|
||||
assert response.pagination.next_cursor == "abc123"
|
||||
assert response.pagination.has_more is True
|
||||
|
||||
def test_is_subclass_of_paginated_response(self):
|
||||
"""CursorPaginatedResponse IS a PaginatedResponse."""
|
||||
response = CursorPaginatedResponse(
|
||||
data=[],
|
||||
pagination=CursorPagination(
|
||||
next_cursor=None, items_per_page=10, has_more=False
|
||||
),
|
||||
)
|
||||
assert isinstance(response, PaginatedResponse)
|
||||
|
||||
def test_pagination_type_default_cannot_be_overridden_to_offset(self):
|
||||
"""pagination_type rejects values other than CURSOR."""
|
||||
with pytest.raises(ValidationError):
|
||||
CursorPaginatedResponse(
|
||||
data=[],
|
||||
pagination=CursorPagination(
|
||||
next_cursor=None, items_per_page=10, has_more=False
|
||||
),
|
||||
pagination_type=PaginationType.OFFSET, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
def test_full_serialization(self):
|
||||
"""Full JSON serialization includes all expected fields."""
|
||||
response = CursorPaginatedResponse(
|
||||
data=[{"id": 1}],
|
||||
pagination=CursorPagination(
|
||||
next_cursor="tok_next",
|
||||
prev_cursor="tok_prev",
|
||||
items_per_page=10,
|
||||
has_more=True,
|
||||
),
|
||||
)
|
||||
data = response.model_dump(mode="json")
|
||||
|
||||
assert data["pagination_type"] == "cursor"
|
||||
assert data["status"] == "SUCCESS"
|
||||
assert data["pagination"]["next_cursor"] == "tok_next"
|
||||
assert data["pagination"]["prev_cursor"] == "tok_prev"
|
||||
|
||||
|
||||
class TestFromAttributes:
|
||||
"""Tests for from_attributes config (ORM mode)."""
|
||||
|
||||
147
zensical.toml
Normal file
147
zensical.toml
Normal file
@@ -0,0 +1,147 @@
|
||||
[project]
|
||||
site_name = "FastAPI Toolsets"
|
||||
site_description = "Production-ready utilities for FastAPI applications."
|
||||
site_author = "d3vyce"
|
||||
site_url = "https://fastapi-toolsets.d3vyce.fr"
|
||||
copyright = "Copyright © 2026 d3vyce"
|
||||
repo_url = "https://github.com/d3vyce/fastapi-toolsets"
|
||||
|
||||
[project.theme]
|
||||
custom_dir = "docs/overrides"
|
||||
language = "en"
|
||||
features = [
|
||||
"announce.dismiss",
|
||||
"content.action.view",
|
||||
"content.code.annotate",
|
||||
"content.code.copy",
|
||||
"content.code.select",
|
||||
"content.footnote.tooltips",
|
||||
"content.tabs.link",
|
||||
"content.tooltips",
|
||||
"navigation.footer",
|
||||
"navigation.indexes",
|
||||
"navigation.instant",
|
||||
"navigation.instant.prefetch",
|
||||
"navigation.path",
|
||||
"navigation.sections",
|
||||
"navigation.tabs",
|
||||
"navigation.top",
|
||||
"navigation.tracking",
|
||||
"search.highlight",
|
||||
]
|
||||
|
||||
[[project.theme.palette]]
|
||||
scheme = "default"
|
||||
toggle.icon = "lucide/sun"
|
||||
toggle.name = "Switch to dark mode"
|
||||
|
||||
[[project.theme.palette]]
|
||||
scheme = "slate"
|
||||
toggle.icon = "lucide/moon"
|
||||
toggle.name = "Switch to light mode"
|
||||
|
||||
[project.theme.font]
|
||||
text = "Inter"
|
||||
code = "Jetbrains Mono"
|
||||
|
||||
[project.theme.icon]
|
||||
repo = "fontawesome/brands/github"
|
||||
|
||||
[project.plugins.mkdocstrings.handlers.python]
|
||||
inventories = ["https://docs.python.org/3/objects.inv"]
|
||||
paths = ["src"]
|
||||
|
||||
[project.plugins.mkdocstrings.handlers.python.options]
|
||||
docstring_style = "google"
|
||||
inherited_members = true
|
||||
show_source = false
|
||||
show_root_heading = true
|
||||
|
||||
[project.markdown_extensions]
|
||||
abbr = {}
|
||||
admonition = {}
|
||||
attr_list = {}
|
||||
def_list = {}
|
||||
footnotes = {}
|
||||
md_in_html = {}
|
||||
"pymdownx.arithmatex" = {generic = true}
|
||||
"pymdownx.betterem" = {}
|
||||
"pymdownx.caret" = {}
|
||||
"pymdownx.details" = {}
|
||||
"pymdownx.emoji" = {}
|
||||
"pymdownx.inlinehilite" = {}
|
||||
"pymdownx.keys" = {}
|
||||
"pymdownx.magiclink" = {}
|
||||
"pymdownx.mark" = {}
|
||||
"pymdownx.smartsymbols" = {}
|
||||
"pymdownx.tasklist" = {custom_checkbox = true}
|
||||
"pymdownx.tilde" = {}
|
||||
|
||||
[project.markdown_extensions.pymdownx.emoji]
|
||||
emoji_index = "zensical.extensions.emoji.twemoji"
|
||||
emoji_generator = "zensical.extensions.emoji.to_svg"
|
||||
|
||||
[project.markdown_extensions."pymdownx.highlight"]
|
||||
anchor_linenums = true
|
||||
line_spans = "__span"
|
||||
pygments_lang_class = true
|
||||
|
||||
[project.markdown_extensions."pymdownx.superfences"]
|
||||
custom_fences = [{name = "mermaid", class = "mermaid"}]
|
||||
|
||||
[project.markdown_extensions."pymdownx.tabbed"]
|
||||
alternate_style = true
|
||||
combine_header_slug = true
|
||||
|
||||
[project.markdown_extensions."toc"]
|
||||
permalink = true
|
||||
|
||||
[project.markdown_extensions."pymdownx.snippets"]
|
||||
base_path = ["."]
|
||||
check_paths = true
|
||||
|
||||
[[project.nav]]
|
||||
Home = "index.md"
|
||||
|
||||
[[project.nav]]
|
||||
Modules = [
|
||||
{CLI = "module/cli.md"},
|
||||
{CRUD = "module/crud.md"},
|
||||
{Database = "module/db.md"},
|
||||
{Dependencies = "module/dependencies.md"},
|
||||
{Exceptions = "module/exceptions.md"},
|
||||
{Fixtures = "module/fixtures.md"},
|
||||
{Logger = "module/logger.md"},
|
||||
{Metrics = "module/metrics.md"},
|
||||
{Models = "module/models.md"},
|
||||
{Pytest = "module/pytest.md"},
|
||||
{Schemas = "module/schemas.md"},
|
||||
]
|
||||
|
||||
[[project.nav]]
|
||||
Reference = [
|
||||
{CLI = "reference/cli.md"},
|
||||
{CRUD = "reference/crud.md"},
|
||||
{Database = "reference/db.md"},
|
||||
{Dependencies = "reference/dependencies.md"},
|
||||
{Exceptions = "reference/exceptions.md"},
|
||||
{Fixtures = "reference/fixtures.md"},
|
||||
{Logger = "reference/logger.md"},
|
||||
{Metrics = "reference/metrics.md"},
|
||||
{Models = "reference/models.md"},
|
||||
{Pytest = "reference/pytest.md"},
|
||||
{Schemas = "reference/schemas.md"},
|
||||
]
|
||||
|
||||
[[project.nav]]
|
||||
Examples = [
|
||||
{"Pagination & Search" = "examples/pagination-search.md"},
|
||||
]
|
||||
|
||||
[[project.nav]]
|
||||
Migration = [
|
||||
{"v2.0" = "migration/v2.md"},
|
||||
]
|
||||
|
||||
[[project.nav]]
|
||||
"Changelog ↗" = "https://github.com/d3vyce/fastapi-toolsets/releases"
|
||||
Reference in New Issue
Block a user