mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
Compare commits
70 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
56d365d14b
|
|||
|
|
a257d85d45 | ||
|
117675d02f
|
|||
|
|
d7ad7308c5 | ||
|
8d57bf9525
|
|||
|
|
5a08ec2f57 | ||
|
|
433dc55fcd | ||
|
|
0b2abd8c43 | ||
|
|
07c99be89b | ||
|
|
9b75cc7dfc | ||
|
|
6144b383eb | ||
|
7ec407834a
|
|||
|
|
7da34f33a2 | ||
|
8c8911fb27
|
|||
|
|
c0c3b38054 | ||
|
e17d385910
|
|||
|
|
6cf7df55ef | ||
|
|
7482bc5dad | ||
|
|
9d07dfea85 | ||
|
|
31678935aa | ||
|
|
823a0b3e36 | ||
|
1591cd3d64
|
|||
|
|
6714ceeb92 | ||
|
|
73fae04333 | ||
|
|
32ed36e102 | ||
|
|
48567310bc | ||
|
|
de51ed4675 | ||
|
|
794767edbb | ||
|
|
9c136f05bb | ||
|
3299a439fe
|
|||
|
|
d5b22a72fd | ||
|
|
c32f2e18be | ||
|
d971261f98
|
|||
|
|
74a54b7396 | ||
|
|
19805ab376 | ||
|
|
d4498e2063 | ||
| f59c1a17e2 | |||
|
|
8982ba18e3 | ||
|
|
71fe6f478f | ||
|
|
1cfbf14986 | ||
|
|
e3ff535b7e | ||
|
|
8825c772ce | ||
|
|
c8c263ca8f | ||
|
2020fa2f92
|
|||
|
|
1ea316bef4 | ||
|
|
ced1a655f2 | ||
|
|
290b2a06ec | ||
|
|
baa9711665 | ||
|
d526969d0e
|
|||
|
|
e24153053e | ||
|
348ed4c148
|
|||
|
bd6e90de1b
|
|||
|
|
4404fb3df9 | ||
|
|
f68793fbdb | ||
|
|
3a69c3c788 | ||
|
e861a0a49a
|
|||
|
|
cb2cf572e0 | ||
|
494869a172
|
|||
|
|
e0bc93096d | ||
|
1ff94eb9d3
|
|||
|
|
97ab10edcd | ||
|
|
3ff7ff18bb | ||
|
0f50c8a0f0
|
|||
|
|
691fb78fda | ||
|
|
34ef4da317 | ||
|
|
8c287b3ce7 | ||
|
54f5479c24
|
|||
|
|
f467754df1 | ||
|
b57ce40b05
|
|||
|
5264631550
|
2
.github/workflows/build-release.yml
vendored
2
.github/workflows/build-release.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
|||||||
run: uv python install 3.14
|
run: uv python install 3.14
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync
|
run: uv sync --group dev
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
run: uv build
|
run: uv build
|
||||||
|
|||||||
16
.github/workflows/ci.yml
vendored
16
.github/workflows/ci.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
|||||||
run: uv python install 3.13
|
run: uv python install 3.13
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync --extra dev
|
run: uv sync --group dev
|
||||||
|
|
||||||
- name: Run Ruff linter
|
- name: Run Ruff linter
|
||||||
run: uv run ruff check .
|
run: uv run ruff check .
|
||||||
@@ -45,7 +45,7 @@ jobs:
|
|||||||
run: uv python install 3.13
|
run: uv python install 3.13
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync --extra dev
|
run: uv sync --group dev
|
||||||
|
|
||||||
- name: Run ty
|
- name: Run ty
|
||||||
run: uv run ty check
|
run: uv run ty check
|
||||||
@@ -83,18 +83,26 @@ jobs:
|
|||||||
run: uv python install ${{ matrix.python-version }}
|
run: uv python install ${{ matrix.python-version }}
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync --extra dev
|
run: uv sync --group dev
|
||||||
|
|
||||||
- name: Run tests with coverage
|
- name: Run tests with coverage
|
||||||
env:
|
env:
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/test_db
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/test_db
|
||||||
run: |
|
run: |
|
||||||
uv run pytest --cov --cov-report=xml --cov-report=term-missing
|
uv run pytest --cov --cov-report=xml --cov-report=term-missing --junitxml=junit.xml -o junit_family=legacy
|
||||||
|
|
||||||
- name: Upload coverage to Codecov
|
- name: Upload coverage to Codecov
|
||||||
if: matrix.python-version == '3.14'
|
if: matrix.python-version == '3.14'
|
||||||
uses: codecov/codecov-action@v5
|
uses: codecov/codecov-action@v5
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
|
report_type: coverage
|
||||||
files: ./coverage.xml
|
files: ./coverage.xml
|
||||||
fail_ci_if_error: false
|
fail_ci_if_error: false
|
||||||
|
|
||||||
|
- name: Upload test results to Codecov
|
||||||
|
if: matrix.python-version == '3.14'
|
||||||
|
uses: codecov/codecov-action@v5
|
||||||
|
with:
|
||||||
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
|
report_type: test_results
|
||||||
|
|||||||
38
.github/workflows/docs.yml
vendored
Normal file
38
.github/workflows/docs.yml
vendored
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
name: Documentation
|
||||||
|
|
||||||
|
on:
|
||||||
|
release:
|
||||||
|
types: [published]
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
pages: write
|
||||||
|
id-token: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
deploy:
|
||||||
|
environment:
|
||||||
|
name: github-pages
|
||||||
|
url: ${{ steps.deployment.outputs.page_url }}
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/configure-pages@v5
|
||||||
|
|
||||||
|
- uses: actions/checkout@v6
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
run: uv python install 3.13
|
||||||
|
|
||||||
|
- run: uv sync --group dev
|
||||||
|
|
||||||
|
- run: uv run zensical build --clean
|
||||||
|
|
||||||
|
- uses: actions/upload-pages-artifact@v4
|
||||||
|
with:
|
||||||
|
path: site
|
||||||
|
|
||||||
|
- uses: actions/deploy-pages@v4
|
||||||
|
id: deployment
|
||||||
36
README.md
36
README.md
@@ -1,6 +1,6 @@
|
|||||||
# FastAPI Toolsets
|
# FastAPI Toolsets
|
||||||
|
|
||||||
FastAPI Toolsets provides production-ready utilities for FastAPI applications built with async SQLAlchemy and PostgreSQL. It includes generic CRUD operations, a fixture system with dependency resolution, a Django-like CLI, standardized API responses, and structured exception handling with automatic OpenAPI documentation.
|
A modular collection of production-ready utilities for FastAPI. Install only what you need — from async CRUD and database helpers to CLI tooling, Prometheus metrics, and pytest fixtures. Each module is independently installable via optional extras, keeping your dependency footprint minimal.
|
||||||
|
|
||||||
[](https://github.com/d3vyce/fastapi-toolsets/actions/workflows/ci.yml)
|
[](https://github.com/d3vyce/fastapi-toolsets/actions/workflows/ci.yml)
|
||||||
[](https://codecov.io/gh/d3vyce/fastapi-toolsets)
|
[](https://codecov.io/gh/d3vyce/fastapi-toolsets)
|
||||||
@@ -20,17 +20,43 @@ FastAPI Toolsets provides production-ready utilities for FastAPI applications bu
|
|||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
|
The base package includes the core modules (CRUD, database, schemas, exceptions, fixtures, dependencies, logging):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv add fastapi-toolsets
|
uv add fastapi-toolsets
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Install only the extras you need:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add "fastapi-toolsets[cli]" # CLI (typer)
|
||||||
|
uv add "fastapi-toolsets[metrics]" # Prometheus metrics (prometheus_client)
|
||||||
|
uv add "fastapi-toolsets[pytest]" # Pytest helpers (httpx, pytest-xdist)
|
||||||
|
```
|
||||||
|
|
||||||
|
Or install everything:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add "fastapi-toolsets[all]"
|
||||||
|
```
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- **CRUD**: Generic async CRUD operations with `CrudFactory`
|
### Core
|
||||||
- **Fixtures**: Fixture system with dependency management, context support and pytest integration
|
|
||||||
- **CLI**: Django-like command-line interface for fixtures and custom commands
|
- **CRUD**: Generic async CRUD operations with `CrudFactory`, built-in full-text/faceted search and Offset/Cursor pagination.
|
||||||
- **Standardized API Responses**: Consistent response format across your API
|
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
|
||||||
|
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
|
||||||
|
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
|
||||||
|
- **Standardized API Responses**: Consistent response format with `Response`, `PaginatedResponse`, and `PydanticBase`
|
||||||
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
||||||
|
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`
|
||||||
|
|
||||||
|
### Optional
|
||||||
|
|
||||||
|
- **CLI**: Django-like command-line interface with fixture management and custom commands support
|
||||||
|
- **Metrics**: Prometheus metrics endpoint with provider/collector registry
|
||||||
|
- **Pytest Helpers**: Async test client, database session management, `pytest-xdist` support, and table cleanup utilities
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
|
|||||||
134
docs/examples/pagination-search.md
Normal file
134
docs/examples/pagination-search.md
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
# Pagination & search
|
||||||
|
|
||||||
|
This example builds an articles listing endpoint that supports **offset pagination**, **cursor pagination**, **full-text search**, **faceted filtering**, and **sorting** — all from a single `CrudFactory` definition.
|
||||||
|
|
||||||
|
## Models
|
||||||
|
|
||||||
|
```python title="models.py"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/models.py"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Schemas
|
||||||
|
|
||||||
|
```python title="schemas.py"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/schemas.py"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Crud
|
||||||
|
|
||||||
|
Declare `searchable_fields`, `facet_fields`, and `order_fields` once on [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory). All endpoints built from this class share the same defaults and can override them per call.
|
||||||
|
|
||||||
|
```python title="crud.py"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/crud.py"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Session dependency
|
||||||
|
|
||||||
|
```python title="db.py"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/db.py"
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! info "Deploy a Postgres DB with docker"
|
||||||
|
```bash
|
||||||
|
docker run -d --name postgres -e POSTGRES_USER=postgres -e POSTGRES_PASSWORD=postgres -e POSTGRES_DB=postgres -p 5432:5432 postgres:18-alpine
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## App
|
||||||
|
|
||||||
|
```python title="app.py"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/app.py"
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Routes
|
||||||
|
### Offset pagination
|
||||||
|
|
||||||
|
Best for admin panels or any UI that needs a total item count and numbered pages.
|
||||||
|
|
||||||
|
```python title="routes.py:1:36"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/routes.py:1:36"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example request**
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /articles/offset?page=2&items_per_page=10&search=fastapi&status=published&order_by=title&order=asc
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example response**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"data": [
|
||||||
|
{ "id": "3f47ac69-...", "title": "FastAPI tips", "status": "published", ... }
|
||||||
|
],
|
||||||
|
"pagination": {
|
||||||
|
"total_count": 42,
|
||||||
|
"page": 2,
|
||||||
|
"items_per_page": 10,
|
||||||
|
"has_more": true
|
||||||
|
},
|
||||||
|
"filter_attributes": {
|
||||||
|
"status": ["archived", "draft", "published"],
|
||||||
|
"name": ["backend", "frontend", "python"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
`filter_attributes` always reflects the values visible **after** applying the active filters. Use it to populate filter dropdowns on the client.
|
||||||
|
|
||||||
|
### Cursor pagination
|
||||||
|
|
||||||
|
Best for feeds, infinite scroll, or any high-throughput API where offset performance degrades.
|
||||||
|
|
||||||
|
```python title="routes.py:39:59"
|
||||||
|
--8<-- "docs_src/examples/pagination_search/routes.py:39:59"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example request**
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /articles/cursor?items_per_page=10&status=published&order_by=created_at&order=desc
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example response**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"data": [
|
||||||
|
{ "id": "3f47ac69-...", "title": "FastAPI tips", "status": "published", ... }
|
||||||
|
],
|
||||||
|
"pagination": {
|
||||||
|
"next_cursor": "eyJ2YWx1ZSI6ICIzZjQ3YWM2OS0uLi4ifQ==",
|
||||||
|
"prev_cursor": null,
|
||||||
|
"items_per_page": 10,
|
||||||
|
"has_more": true
|
||||||
|
},
|
||||||
|
"filter_attributes": {
|
||||||
|
"status": ["published"],
|
||||||
|
"name": ["backend", "python"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Pass `next_cursor` as the `cursor` query parameter on the next request to advance to the next page.
|
||||||
|
|
||||||
|
## Search behaviour
|
||||||
|
|
||||||
|
Both endpoints inherit the same `searchable_fields` declared on `ArticleCrud`:
|
||||||
|
|
||||||
|
Search is **case-insensitive** and uses a `LIKE %query%` pattern. Pass a [`SearchConfig`](../reference/crud.md#fastapi_toolsets.crud.search.SearchConfig) instead of a plain string to control case sensitivity or switch to `match_mode="all"` (AND across all fields instead of OR).
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.crud import SearchConfig
|
||||||
|
|
||||||
|
# Both title AND body must contain "fastapi"
|
||||||
|
result = await ArticleCrud.offset_paginate(
|
||||||
|
session,
|
||||||
|
search=SearchConfig(query="fastapi", case_sensitive=True, match_mode="all"),
|
||||||
|
search_fields=[Article.title, Article.body],
|
||||||
|
)
|
||||||
|
```
|
||||||
67
docs/index.md
Normal file
67
docs/index.md
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
# FastAPI Toolsets
|
||||||
|
|
||||||
|
A modular collection of production-ready utilities for FastAPI. Install only what you need — from async CRUD and database helpers to CLI tooling, Prometheus metrics, and pytest fixtures. Each module is independently installable via optional extras, keeping your dependency footprint minimal.
|
||||||
|
|
||||||
|
[](https://github.com/d3vyce/fastapi-toolsets/actions/workflows/ci.yml)
|
||||||
|
[](https://codecov.io/gh/d3vyce/fastapi-toolsets)
|
||||||
|
[](https://github.com/astral-sh/ty)
|
||||||
|
[](https://github.com/astral-sh/uv)
|
||||||
|
[](https://github.com/astral-sh/ruff)
|
||||||
|
[](https://www.python.org/downloads/)
|
||||||
|
[](https://opensource.org/licenses/MIT)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Documentation**: [https://fastapi-toolsets.d3vyce.fr](https://fastapi-toolsets.d3vyce.fr)
|
||||||
|
|
||||||
|
**Source Code**: [https://github.com/d3vyce/fastapi-toolsets](https://github.com/d3vyce/fastapi-toolsets)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
The base package includes the core modules (CRUD, database, schemas, exceptions, fixtures, dependencies, logging):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add fastapi-toolsets
|
||||||
|
```
|
||||||
|
|
||||||
|
Install only the extras you need:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add "fastapi-toolsets[cli]" # CLI (typer)
|
||||||
|
uv add "fastapi-toolsets[metrics]" # Prometheus metrics (prometheus_client)
|
||||||
|
uv add "fastapi-toolsets[pytest]" # Pytest helpers (httpx, pytest-xdist)
|
||||||
|
```
|
||||||
|
|
||||||
|
Or install everything:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add "fastapi-toolsets[all]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
### Core
|
||||||
|
|
||||||
|
- **CRUD**: Generic async CRUD operations with `CrudFactory`, built-in full-text/faceted search and offset/cursor pagination.
|
||||||
|
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
|
||||||
|
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
|
||||||
|
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
|
||||||
|
- **Standardized API Responses**: Consistent response format with `Response`, `PaginatedResponse`, and `PydanticBase`
|
||||||
|
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
||||||
|
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`
|
||||||
|
|
||||||
|
### Optional
|
||||||
|
|
||||||
|
- **CLI**: Django-like command-line interface with fixture management and custom commands support
|
||||||
|
- **Metrics**: Prometheus metrics endpoint with provider/collector registry
|
||||||
|
- **Pytest Helpers**: Async test client, database session management, `pytest-xdist` support, and table cleanup utilities
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT License - see [LICENSE](LICENSE) for details.
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
Contributions are welcome! Please feel free to submit issues and pull requests.
|
||||||
93
docs/module/cli.md
Normal file
93
docs/module/cli.md
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
# CLI
|
||||||
|
|
||||||
|
Typer-based command-line interface for managing your FastAPI application, with built-in fixture commands integration.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
=== "uv"
|
||||||
|
``` bash
|
||||||
|
uv add "fastapi-toolsets[cli]"
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "pip"
|
||||||
|
``` bash
|
||||||
|
pip install "fastapi-toolsets[cli]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `cli` module provides a `manager` entry point built with [Typer](https://typer.tiangolo.com/). It allow custom commands to be added in addition of the fixture commands when a [`FixtureRegistry`](../reference/fixtures.md#fastapi_toolsets.fixtures.registry.FixtureRegistry) and a database context are configured.
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Configure the CLI in your `pyproject.toml`:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[tool.fastapi-toolsets]
|
||||||
|
cli = "myapp.cli:cli" # Custom Typer app
|
||||||
|
fixtures = "myapp.fixtures:registry" # FixtureRegistry instance
|
||||||
|
db_context = "myapp.db:db_context" # Async context manager for sessions
|
||||||
|
```
|
||||||
|
|
||||||
|
All fields are optional. Without configuration, the `manager` command still works but no command are available.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Manager commands
|
||||||
|
manager --help
|
||||||
|
|
||||||
|
Usage: manager [OPTIONS] COMMAND [ARGS]...
|
||||||
|
|
||||||
|
FastAPI utilities CLI.
|
||||||
|
|
||||||
|
╭─ Options ────────────────────────────────────────────────────────────────────────╮
|
||||||
|
│ --install-completion Install completion for the current shell. │
|
||||||
|
│ --show-completion Show completion for the current shell, to copy it │
|
||||||
|
│ or customize the installation. │
|
||||||
|
│ --help Show this message and exit. │
|
||||||
|
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||||
|
╭─ Commands ───────────────────────────────────────────────────────────────────────╮
|
||||||
|
│ check-db │
|
||||||
|
│ fixtures Manage database fixtures. │
|
||||||
|
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||||
|
|
||||||
|
# Fixtures commands
|
||||||
|
manager fixtures --help
|
||||||
|
|
||||||
|
Usage: manager fixtures [OPTIONS] COMMAND [ARGS]...
|
||||||
|
|
||||||
|
Manage database fixtures.
|
||||||
|
|
||||||
|
╭─ Options ────────────────────────────────────────────────────────────────────────╮
|
||||||
|
│ --help Show this message and exit. │
|
||||||
|
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||||
|
╭─ Commands ───────────────────────────────────────────────────────────────────────╮
|
||||||
|
│ list List all registered fixtures. │
|
||||||
|
│ load Load fixtures into the database. │
|
||||||
|
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||||
|
```
|
||||||
|
|
||||||
|
## Custom CLI
|
||||||
|
|
||||||
|
You can extend the CLI by providing your own Typer app. The `manager` entry point will merge your app's commands with the built-in ones:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# myapp/cli.py
|
||||||
|
import typer
|
||||||
|
|
||||||
|
cli = typer.Typer()
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
|
def hello():
|
||||||
|
print("Hello from my app!")
|
||||||
|
```
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[tool.fastapi-toolsets]
|
||||||
|
cli = "myapp.cli:cli"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/cli.md)
|
||||||
479
docs/module/crud.md
Normal file
479
docs/module/crud.md
Normal file
@@ -0,0 +1,479 @@
|
|||||||
|
# CRUD
|
||||||
|
|
||||||
|
Generic async CRUD operations for SQLAlchemy models with search, pagination, and many-to-many support.
|
||||||
|
|
||||||
|
!!! info
|
||||||
|
This module has been coded and tested to be compatible with PostgreSQL only.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `crud` module provides [`AsyncCrud`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud), an abstract base class with a full suite of async database operations, and [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory), a convenience function to instantiate it for a given model.
|
||||||
|
|
||||||
|
## Creating a CRUD class
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.crud import CrudFactory
|
||||||
|
from myapp.models import User
|
||||||
|
|
||||||
|
UserCrud = CrudFactory(model=User)
|
||||||
|
```
|
||||||
|
|
||||||
|
[`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory) dynamically creates a class named `AsyncUserCrud` with `User` as its model.
|
||||||
|
|
||||||
|
## Basic operations
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Create
|
||||||
|
user = await UserCrud.create(session=session, obj=UserCreateSchema(username="alice"))
|
||||||
|
|
||||||
|
# Get one (raises NotFoundError if not found)
|
||||||
|
user = await UserCrud.get(session=session, filters=[User.id == user_id])
|
||||||
|
|
||||||
|
# Get first or None
|
||||||
|
user = await UserCrud.first(session=session, filters=[User.email == email])
|
||||||
|
|
||||||
|
# Get multiple
|
||||||
|
users = await UserCrud.get_multi(session=session, filters=[User.is_active == True])
|
||||||
|
|
||||||
|
# Update
|
||||||
|
user = await UserCrud.update(session=session, obj=UserUpdateSchema(username="bob"), filters=[User.id == user_id])
|
||||||
|
|
||||||
|
# Delete
|
||||||
|
await UserCrud.delete(session=session, filters=[User.id == user_id])
|
||||||
|
|
||||||
|
# Count / exists
|
||||||
|
count = await UserCrud.count(session=session, filters=[User.is_active == True])
|
||||||
|
exists = await UserCrud.exists(session=session, filters=[User.email == email])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pagination
|
||||||
|
|
||||||
|
!!! info "Added in `v1.1` (only offset_pagination via `paginate` if `<v1.1`)"
|
||||||
|
|
||||||
|
Two pagination strategies are available. Both return a [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse) but differ in how they navigate through results.
|
||||||
|
|
||||||
|
| | `offset_paginate` | `cursor_paginate` |
|
||||||
|
|---|---|---|
|
||||||
|
| Total count | Yes | No |
|
||||||
|
| Jump to arbitrary page | Yes | No |
|
||||||
|
| Performance on deep pages | Degrades | Constant |
|
||||||
|
| Stable under concurrent inserts | No | Yes |
|
||||||
|
| Search compatible | Yes | Yes |
|
||||||
|
| Use case | Admin panels, numbered pagination | Feeds, APIs, infinite scroll |
|
||||||
|
|
||||||
|
### Offset pagination
|
||||||
|
|
||||||
|
```python
|
||||||
|
@router.get(
|
||||||
|
"",
|
||||||
|
response_model=PaginatedResponse[User],
|
||||||
|
)
|
||||||
|
async def get_users(
|
||||||
|
session: SessionDep,
|
||||||
|
items_per_page: int = 50,
|
||||||
|
page: int = 1,
|
||||||
|
):
|
||||||
|
return await crud.UserCrud.offset_paginate(
|
||||||
|
session=session,
|
||||||
|
items_per_page=items_per_page,
|
||||||
|
page=page,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
The [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) method returns a [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse) whose `pagination` field is an [`OffsetPagination`](../reference/schemas.md#fastapi_toolsets.schemas.OffsetPagination) object:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"data": ["..."],
|
||||||
|
"pagination": {
|
||||||
|
"total_count": 100,
|
||||||
|
"page": 1,
|
||||||
|
"items_per_page": 20,
|
||||||
|
"has_more": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! warning "Deprecated: `paginate`"
|
||||||
|
The `paginate` function is a backward-compatible alias for `offset_paginate`. This function is **deprecated** and will be removed in **v2.0**.
|
||||||
|
|
||||||
|
### Cursor pagination
|
||||||
|
|
||||||
|
```python
|
||||||
|
@router.get(
|
||||||
|
"",
|
||||||
|
response_model=PaginatedResponse[UserRead],
|
||||||
|
)
|
||||||
|
async def list_users(
|
||||||
|
session: SessionDep,
|
||||||
|
cursor: str | None = None,
|
||||||
|
items_per_page: int = 20,
|
||||||
|
):
|
||||||
|
return await UserCrud.cursor_paginate(
|
||||||
|
session=session,
|
||||||
|
cursor=cursor,
|
||||||
|
items_per_page=items_per_page,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
The [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) method returns a [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse) whose `pagination` field is a [`CursorPagination`](../reference/schemas.md#fastapi_toolsets.schemas.CursorPagination) object:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"data": ["..."],
|
||||||
|
"pagination": {
|
||||||
|
"next_cursor": "eyJ2YWx1ZSI6ICIzZjQ3YWM2OS0uLi4ifQ==",
|
||||||
|
"prev_cursor": null,
|
||||||
|
"items_per_page": 20,
|
||||||
|
"has_more": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Pass `next_cursor` as the `cursor` query parameter on the next request to advance to the next page. `prev_cursor` is set on pages 2+ and points back to the first item of the current page. Both are `null` when there is no adjacent page.
|
||||||
|
|
||||||
|
#### Choosing a cursor column
|
||||||
|
|
||||||
|
The cursor column is set once on [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory) via the `cursor_column` parameter. It must be monotonically ordered for stable results:
|
||||||
|
|
||||||
|
- Auto-increment integer PKs
|
||||||
|
- UUID v7 PKs
|
||||||
|
- Timestamps
|
||||||
|
|
||||||
|
!!! warning
|
||||||
|
Random UUID v4 PKs are **not** suitable as cursor columns because their ordering is non-deterministic.
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
`cursor_column` is required. Calling [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) on a CRUD class that has no `cursor_column` configured raises a `ValueError`.
|
||||||
|
|
||||||
|
The cursor value is base64-encoded when returned to the client and decoded back to the correct Python type on the next request. The following SQLAlchemy column types are supported:
|
||||||
|
|
||||||
|
| SQLAlchemy type | Python type |
|
||||||
|
|---|---|
|
||||||
|
| `Integer`, `BigInteger`, `SmallInteger` | `int` |
|
||||||
|
| `Uuid` | `uuid.UUID` |
|
||||||
|
| `DateTime` | `datetime.datetime` |
|
||||||
|
| `Date` | `datetime.date` |
|
||||||
|
| `Float`, `Numeric` | `decimal.Decimal` |
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Paginate by the primary key
|
||||||
|
PostCrud = CrudFactory(model=Post, cursor_column=Post.id)
|
||||||
|
|
||||||
|
# Paginate by a timestamp column instead
|
||||||
|
PostCrud = CrudFactory(model=Post, cursor_column=Post.created_at)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Search
|
||||||
|
|
||||||
|
Two search strategies are available, both compatible with [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) and [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate).
|
||||||
|
|
||||||
|
| | Full-text search | Faceted search |
|
||||||
|
|---|---|---|
|
||||||
|
| Input | Free-text string | Exact column values |
|
||||||
|
| Relationship support | Yes | Yes |
|
||||||
|
| Use case | Search bars | Filter dropdowns |
|
||||||
|
|
||||||
|
!!! info "You can use both search strategies in the same endpoint!"
|
||||||
|
|
||||||
|
### Full-text search
|
||||||
|
|
||||||
|
Declare `searchable_fields` on the CRUD class. Relationship traversal is supported via tuples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
PostCrud = CrudFactory(
|
||||||
|
model=Post,
|
||||||
|
searchable_fields=[
|
||||||
|
Post.title,
|
||||||
|
Post.content,
|
||||||
|
(Post.author, User.username), # search across relationship
|
||||||
|
],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
You can override `searchable_fields` per call with `search_fields`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
result = await UserCrud.offset_paginate(
|
||||||
|
session=session,
|
||||||
|
search_fields=[User.country],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
This allows searching with both [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) and [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate):
|
||||||
|
|
||||||
|
```python
|
||||||
|
@router.get(
|
||||||
|
"",
|
||||||
|
response_model=PaginatedResponse[User],
|
||||||
|
)
|
||||||
|
async def get_users(
|
||||||
|
session: SessionDep,
|
||||||
|
items_per_page: int = 50,
|
||||||
|
page: int = 1,
|
||||||
|
search: str | None = None,
|
||||||
|
):
|
||||||
|
return await crud.UserCrud.offset_paginate(
|
||||||
|
session=session,
|
||||||
|
items_per_page=items_per_page,
|
||||||
|
page=page,
|
||||||
|
search=search,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
@router.get(
|
||||||
|
"",
|
||||||
|
response_model=PaginatedResponse[User],
|
||||||
|
)
|
||||||
|
async def get_users(
|
||||||
|
session: SessionDep,
|
||||||
|
cursor: str | None = None,
|
||||||
|
items_per_page: int = 50,
|
||||||
|
search: str | None = None,
|
||||||
|
):
|
||||||
|
return await crud.UserCrud.cursor_paginate(
|
||||||
|
session=session,
|
||||||
|
items_per_page=items_per_page,
|
||||||
|
cursor=cursor,
|
||||||
|
search=search,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Faceted search
|
||||||
|
|
||||||
|
!!! info "Added in `v1.2`"
|
||||||
|
|
||||||
|
Declare `facet_fields` on the CRUD class to return distinct column values alongside paginated results. This is useful for populating filter dropdowns or building faceted search UIs.
|
||||||
|
|
||||||
|
Facet fields use the same syntax as `searchable_fields` — direct columns or relationship tuples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
UserCrud = CrudFactory(
|
||||||
|
model=User,
|
||||||
|
facet_fields=[
|
||||||
|
User.status,
|
||||||
|
User.country,
|
||||||
|
(User.role, Role.name), # value from a related model
|
||||||
|
],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
You can override `facet_fields` per call:
|
||||||
|
|
||||||
|
```python
|
||||||
|
result = await UserCrud.offset_paginate(
|
||||||
|
session=session,
|
||||||
|
facet_fields=[User.country],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
The distinct values are returned in the `filter_attributes` field of [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse):
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"data": ["..."],
|
||||||
|
"pagination": { "..." },
|
||||||
|
"filter_attributes": {
|
||||||
|
"status": ["active", "inactive"],
|
||||||
|
"country": ["DE", "FR", "US"],
|
||||||
|
"name": ["admin", "editor", "viewer"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Use `filter_by` to pass the client's chosen filter values directly — no need to build SQLAlchemy conditions by hand. Any unknown key raises [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError).
|
||||||
|
|
||||||
|
!!! info "The keys in `filter_by` are the same keys the client received in `filter_attributes`."
|
||||||
|
Keys are normally the terminal `column.key` (e.g. `"name"` for `Role.name`). When two facet fields share the same column key (e.g. `(Build.project, Project.name)` and `(Build.os, Os.name)`), the relationship name is prepended automatically: `"project__name"` and `"os__name"`.
|
||||||
|
|
||||||
|
`filter_by` and `filters` can be combined — both are applied with AND logic.
|
||||||
|
|
||||||
|
Use [`filter_params()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.filter_params) to generate a dict with the facet filter values from the query parameters:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
|
||||||
|
UserCrud = CrudFactory(
|
||||||
|
model=User,
|
||||||
|
facet_fields=[User.status, User.country, (User.role, Role.name)],
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("", response_model_exclude_none=True)
|
||||||
|
async def list_users(
|
||||||
|
session: SessionDep,
|
||||||
|
page: int = 1,
|
||||||
|
filter_by: Annotated[dict[str, list[str]], Depends(UserCrud.filter_params())],
|
||||||
|
) -> PaginatedResponse[UserRead]:
|
||||||
|
return await UserCrud.offset_paginate(
|
||||||
|
session=session,
|
||||||
|
page=page,
|
||||||
|
filter_by=filter_by,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Both single-value and multi-value query parameters work:
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /users?status=active → filter_by={"status": ["active"]}
|
||||||
|
GET /users?status=active&country=FR → filter_by={"status": ["active"], "country": ["FR"]}
|
||||||
|
GET /users?role=admin&role=editor → filter_by={"role": ["admin", "editor"]} (IN clause)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Sorting
|
||||||
|
|
||||||
|
!!! info "Added in `v1.3`"
|
||||||
|
|
||||||
|
Declare `order_fields` on the CRUD class to expose client-driven column ordering via `order_by` and `order` query parameters.
|
||||||
|
|
||||||
|
```python
|
||||||
|
UserCrud = CrudFactory(
|
||||||
|
model=User,
|
||||||
|
order_fields=[
|
||||||
|
User.name,
|
||||||
|
User.created_at,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Call [`order_params()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.order_params) to generate a FastAPI dependency that maps the query parameters to an [`OrderByClause`](../reference/crud.md#fastapi_toolsets.crud.factory.OrderByClause) expression:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
from fastapi_toolsets.crud import OrderByClause
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def list_users(
|
||||||
|
session: SessionDep,
|
||||||
|
order_by: Annotated[OrderByClause | None, Depends(UserCrud.order_params())],
|
||||||
|
) -> PaginatedResponse[UserRead]:
|
||||||
|
return await UserCrud.offset_paginate(session=session, order_by=order_by)
|
||||||
|
```
|
||||||
|
|
||||||
|
The dependency adds two query parameters to the endpoint:
|
||||||
|
|
||||||
|
| Parameter | Type |
|
||||||
|
| ---------- | --------------- |
|
||||||
|
| `order_by` | `str | null` |
|
||||||
|
| `order` | `asc` or `desc` |
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /users?order_by=name&order=asc → ORDER BY users.name ASC
|
||||||
|
GET /users?order_by=name&order=desc → ORDER BY users.name DESC
|
||||||
|
```
|
||||||
|
|
||||||
|
An unknown `order_by` value raises [`InvalidOrderFieldError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidOrderFieldError) (HTTP 422).
|
||||||
|
|
||||||
|
You can also pass `order_fields` directly to `order_params()` to override the class-level defaults without modifying them:
|
||||||
|
|
||||||
|
```python
|
||||||
|
UserOrderParams = UserCrud.order_params(order_fields=[User.name])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Relationship loading
|
||||||
|
|
||||||
|
!!! info "Added in `v1.1`"
|
||||||
|
|
||||||
|
By default, SQLAlchemy relationships are not loaded unless explicitly requested. Instead of using `lazy="selectin"` on model definitions (which is implicit and applies globally), define a `default_load_options` on the CRUD class to control loading strategy explicitly.
|
||||||
|
|
||||||
|
!!! warning
|
||||||
|
Avoid using `lazy="selectin"` on model relationships. It fires silently on every query, cannot be disabled per-call, and can cause unexpected cascading loads through deep relationship chains. Use `default_load_options` instead.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
ArticleCrud = CrudFactory(
|
||||||
|
model=Article,
|
||||||
|
default_load_options=[
|
||||||
|
selectinload(Article.category),
|
||||||
|
selectinload(Article.tags),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
`default_load_options` applies automatically to all read operations (`get`, `first`, `get_multi`, `offset_paginate`, `cursor_paginate`). When `load_options` is passed at call-site, it **fully replaces** `default_load_options` for that query — giving you precise per-call control:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Only loads category, tags are not loaded
|
||||||
|
article = await ArticleCrud.get(
|
||||||
|
session=session,
|
||||||
|
filters=[Article.id == article_id],
|
||||||
|
load_options=[selectinload(Article.category)],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Loads nothing — useful for write-then-refresh flows or lightweight checks
|
||||||
|
articles = await ArticleCrud.get_multi(session=session, load_options=[])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Many-to-many relationships
|
||||||
|
|
||||||
|
Use `m2m_fields` to map schema fields containing lists of IDs to SQLAlchemy relationships. The CRUD class resolves and validates all IDs before persisting:
|
||||||
|
|
||||||
|
```python
|
||||||
|
PostCrud = CrudFactory(
|
||||||
|
model=Post,
|
||||||
|
m2m_fields={"tag_ids": Post.tags},
|
||||||
|
)
|
||||||
|
|
||||||
|
post = await PostCrud.create(session=session, obj=PostCreateSchema(title="Hello", tag_ids=[1, 2, 3]))
|
||||||
|
```
|
||||||
|
|
||||||
|
## Upsert
|
||||||
|
|
||||||
|
Atomic `INSERT ... ON CONFLICT DO UPDATE` using PostgreSQL:
|
||||||
|
|
||||||
|
```python
|
||||||
|
await UserCrud.upsert(
|
||||||
|
session=session,
|
||||||
|
obj=UserCreateSchema(email="alice@example.com", username="alice"),
|
||||||
|
index_elements=[User.email],
|
||||||
|
set_={"username"},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Response serialization
|
||||||
|
|
||||||
|
!!! info "Added in `v1.1`"
|
||||||
|
|
||||||
|
Pass a Pydantic schema class to `create`, `get`, `update`, or `offset_paginate` to serialize the result directly into that schema and wrap it in a [`Response[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.Response) or [`PaginatedResponse[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse):
|
||||||
|
|
||||||
|
```python
|
||||||
|
class UserRead(PydanticBase):
|
||||||
|
id: UUID
|
||||||
|
username: str
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/{uuid}",
|
||||||
|
responses=generate_error_responses(NotFoundError),
|
||||||
|
)
|
||||||
|
async def get_user(session: SessionDep, uuid: UUID) -> Response[UserRead]:
|
||||||
|
return await crud.UserCrud.get(
|
||||||
|
session=session,
|
||||||
|
filters=[User.id == uuid],
|
||||||
|
schema=UserRead,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def list_users(session: SessionDep, page: int = 1) -> PaginatedResponse[UserRead]:
|
||||||
|
return await crud.UserCrud.offset_paginate(
|
||||||
|
session=session,
|
||||||
|
page=page,
|
||||||
|
schema=UserRead,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
The schema must have `from_attributes=True` (or inherit from [`PydanticBase`](../reference/schemas.md#fastapi_toolsets.schemas.PydanticBase)) so it can be built from SQLAlchemy model instances.
|
||||||
|
|
||||||
|
!!! warning "Deprecated: `as_response`"
|
||||||
|
The `as_response=True` parameter is **deprecated** and will be removed in **v2.0**. Replace it with `schema=YourSchema`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/crud.md)
|
||||||
92
docs/module/db.md
Normal file
92
docs/module/db.md
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
# DB
|
||||||
|
|
||||||
|
SQLAlchemy async session management with transactions, table locking, and row-change polling.
|
||||||
|
|
||||||
|
!!! info
|
||||||
|
This module has been coded and tested to be compatible with PostgreSQL only.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `db` module provides helpers to create FastAPI dependencies and context managers for `AsyncSession`, along with utilities for nested transactions, table lock and polling for row changes.
|
||||||
|
|
||||||
|
## Session dependency
|
||||||
|
|
||||||
|
Use [`create_db_dependency`](../reference/db.md#fastapi_toolsets.db.create_db_dependency) to create a FastAPI dependency that yields a session and auto-commits on success:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||||
|
from fastapi_toolsets.db import create_db_dependency
|
||||||
|
|
||||||
|
engine = create_async_engine(url="postgresql+asyncpg://...", future=True)
|
||||||
|
session_maker = async_sessionmaker(bind=engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
get_db = create_db_dependency(session_maker=session_maker)
|
||||||
|
|
||||||
|
@router.get("/users")
|
||||||
|
async def list_users(session: AsyncSession = Depends(get_db)):
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Session context manager
|
||||||
|
|
||||||
|
Use [`create_db_context`](../reference/db.md#fastapi_toolsets.db.create_db_context) for sessions outside request handlers (e.g. background tasks, CLI commands):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import create_db_context
|
||||||
|
|
||||||
|
db_context = create_db_context(session_maker=session_maker)
|
||||||
|
|
||||||
|
async def seed():
|
||||||
|
async with db_context() as session:
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Nested transactions
|
||||||
|
|
||||||
|
[`get_transaction`](../reference/db.md#fastapi_toolsets.db.get_transaction) handles savepoints automatically, allowing safe nesting:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import get_transaction
|
||||||
|
|
||||||
|
async def create_user_with_role(session=session):
|
||||||
|
async with get_transaction(session=session):
|
||||||
|
...
|
||||||
|
async with get_transaction(session=session): # uses savepoint
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Table locking
|
||||||
|
|
||||||
|
[`lock_tables`](../reference/db.md#fastapi_toolsets.db.lock_tables) acquires PostgreSQL table-level locks before executing critical sections:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import lock_tables
|
||||||
|
|
||||||
|
async with lock_tables(session=session, tables=[User], mode="EXCLUSIVE"):
|
||||||
|
# No other transaction can modify User until this block exits
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
Available lock modes are defined in [`LockMode`](../reference/db.md#fastapi_toolsets.db.LockMode): `ACCESS_SHARE`, `ROW_SHARE`, `ROW_EXCLUSIVE`, `SHARE_UPDATE_EXCLUSIVE`, `SHARE`, `SHARE_ROW_EXCLUSIVE`, `EXCLUSIVE`, `ACCESS_EXCLUSIVE`.
|
||||||
|
|
||||||
|
## Row-change polling
|
||||||
|
|
||||||
|
[`wait_for_row_change`](../reference/db.md#fastapi_toolsets.db.wait_for_row_change) polls a row until a specific column changes value, useful for waiting on async side effects:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import wait_for_row_change
|
||||||
|
|
||||||
|
# Wait up to 30s for order.status to change
|
||||||
|
await wait_for_row_change(
|
||||||
|
session=session,
|
||||||
|
model=Order,
|
||||||
|
pk_value=order_id,
|
||||||
|
columns=[Order.status],
|
||||||
|
interval=1.0,
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/db.md)
|
||||||
50
docs/module/dependencies.md
Normal file
50
docs/module/dependencies.md
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
# Dependencies
|
||||||
|
|
||||||
|
FastAPI dependency factories for automatic model resolution from path and body parameters.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `dependencies` module provides two factory functions that create FastAPI dependencies to fetch a model instance from the database automatically — either from a path parameter or from a request body field — and inject it directly into your route handler.
|
||||||
|
|
||||||
|
## `PathDependency`
|
||||||
|
|
||||||
|
[`PathDependency`](../reference/dependencies.md#fastapi_toolsets.dependencies.PathDependency) resolves a model from a URL path parameter and injects it into the route handler. Raises [`NotFoundError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NotFoundError) automatically if the record does not exist.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.dependencies import PathDependency
|
||||||
|
|
||||||
|
UserDep = PathDependency(model=User, field=User.id, session_dep=get_db)
|
||||||
|
|
||||||
|
@router.get("/users/{user_id}")
|
||||||
|
async def get_user(user: User = UserDep):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
By default the parameter name is inferred from the field (`user_id` for `User.id`). You can override it:
|
||||||
|
|
||||||
|
```python
|
||||||
|
UserDep = PathDependency(model=User, field=User.id, session_dep=get_db, param_name="id")
|
||||||
|
|
||||||
|
@router.get("/users/{id}")
|
||||||
|
async def get_user(user: User = UserDep):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
## `BodyDependency`
|
||||||
|
|
||||||
|
[`BodyDependency`](../reference/dependencies.md#fastapi_toolsets.dependencies.BodyDependency) resolves a model from a field in the request body. Useful when a body contains a foreign key and you want the full object injected:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.dependencies import BodyDependency
|
||||||
|
|
||||||
|
RoleDep = BodyDependency(model=Role, field=Role.id, session_dep=get_db, body_field="role_id")
|
||||||
|
|
||||||
|
@router.post("/users")
|
||||||
|
async def create_user(body: UserCreateSchema, role: Role = RoleDep):
|
||||||
|
user = User(username=body.username, role=role)
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/dependencies.md)
|
||||||
86
docs/module/exceptions.md
Normal file
86
docs/module/exceptions.md
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
# Exceptions
|
||||||
|
|
||||||
|
Structured API exceptions with consistent error responses and automatic OpenAPI documentation.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `exceptions` module provides a set of pre-built HTTP exceptions and a FastAPI exception handler that formats all errors — including validation errors — into a uniform [`ErrorResponse`](../reference/schemas.md#fastapi_toolsets.schemas.ErrorResponse).
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
Register the exception handlers on your FastAPI app at startup:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app=app)
|
||||||
|
```
|
||||||
|
|
||||||
|
This registers handlers for:
|
||||||
|
|
||||||
|
- [`ApiException`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ApiException) — all custom exceptions below
|
||||||
|
- `RequestValidationError` — Pydantic request validation (422)
|
||||||
|
- `ResponseValidationError` — Pydantic response validation (422)
|
||||||
|
- `Exception` — unhandled errors (500)
|
||||||
|
|
||||||
|
## Built-in exceptions
|
||||||
|
|
||||||
|
| Exception | Status | Default message |
|
||||||
|
|-----------|--------|-----------------|
|
||||||
|
| [`UnauthorizedError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.UnauthorizedError) | 401 | Unauthorized |
|
||||||
|
| [`ForbiddenError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ForbiddenError) | 403 | Forbidden |
|
||||||
|
| [`NotFoundError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NotFoundError) | 404 | Not found |
|
||||||
|
| [`ConflictError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ConflictError) | 409 | Conflict |
|
||||||
|
| [`NoSearchableFieldsError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError) | 400 | No searchable fields |
|
||||||
|
| [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError) | 400 | Invalid facet filter |
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.exceptions import NotFoundError
|
||||||
|
|
||||||
|
@router.get("/users/{id}")
|
||||||
|
async def get_user(id: int, session: AsyncSession = Depends(get_db)):
|
||||||
|
user = await UserCrud.first(session=session, filters=[User.id == id])
|
||||||
|
if not user:
|
||||||
|
raise NotFoundError
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
## Custom exceptions
|
||||||
|
|
||||||
|
Subclass [`ApiException`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ApiException) and define an `api_error` class variable:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.exceptions import ApiException
|
||||||
|
from fastapi_toolsets.schemas import ApiError
|
||||||
|
|
||||||
|
class PaymentRequiredError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=402,
|
||||||
|
msg="Payment required",
|
||||||
|
desc="Your subscription has expired.",
|
||||||
|
err_code="PAYMENT_REQUIRED",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## OpenAPI response documentation
|
||||||
|
|
||||||
|
Use [`generate_error_responses`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.generate_error_responses) to add error schemas to your endpoint's OpenAPI spec:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.exceptions import generate_error_responses, NotFoundError, ForbiddenError
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/users/{id}",
|
||||||
|
responses=generate_error_responses(NotFoundError, ForbiddenError),
|
||||||
|
)
|
||||||
|
async def get_user(...): ...
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! info
|
||||||
|
The pydantic validation error is automatically added by FastAPI.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/exceptions.md)
|
||||||
123
docs/module/fixtures.md
Normal file
123
docs/module/fixtures.md
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
# Fixtures
|
||||||
|
|
||||||
|
Dependency-aware database seeding with context-based loading strategies.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `fixtures` module lets you define named fixtures with dependencies between them, then load them into the database in the correct order. Fixtures can be scoped to contexts (e.g. base data, testing data) so that only the relevant ones are loaded for each environment.
|
||||||
|
|
||||||
|
## Defining fixtures
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import FixtureRegistry, Context
|
||||||
|
|
||||||
|
fixtures = FixtureRegistry()
|
||||||
|
|
||||||
|
@fixtures.register
|
||||||
|
def roles():
|
||||||
|
return [
|
||||||
|
Role(id=1, name="admin"),
|
||||||
|
Role(id=2, name="user"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
|
||||||
|
def test_users():
|
||||||
|
return [
|
||||||
|
User(id=1, username="alice", role_id=1),
|
||||||
|
User(id=2, username="bob", role_id=2),
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
Dependencies declared via `depends_on` are resolved topologically — `roles` will always be loaded before `test_users`.
|
||||||
|
|
||||||
|
## Loading fixtures
|
||||||
|
|
||||||
|
By context with [`load_fixtures_by_context`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.load_fixtures_by_context):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import load_fixtures_by_context
|
||||||
|
|
||||||
|
async with db_context() as session:
|
||||||
|
await load_fixtures_by_context(session=session, registry=fixtures, context=Context.TESTING)
|
||||||
|
```
|
||||||
|
|
||||||
|
Directly with [`load_fixtures`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.load_fixtures):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import load_fixtures
|
||||||
|
|
||||||
|
async with db_context() as session:
|
||||||
|
await load_fixtures(session=session, registry=fixtures)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Contexts
|
||||||
|
|
||||||
|
[`Context`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.Context) is an enum with predefined values:
|
||||||
|
|
||||||
|
| Context | Description |
|
||||||
|
|---------|-------------|
|
||||||
|
| `Context.BASE` | Core data required in all environments |
|
||||||
|
| `Context.TESTING` | Data only loaded during tests |
|
||||||
|
| `Context.PRODUCTION` | Data only loaded in production |
|
||||||
|
|
||||||
|
A fixture with no `contexts` defined takes `Context.BASE` by default.
|
||||||
|
|
||||||
|
## Load strategies
|
||||||
|
|
||||||
|
[`LoadStrategy`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.LoadStrategy) controls how the fixture loader handles rows that already exist:
|
||||||
|
|
||||||
|
| Strategy | Description |
|
||||||
|
|----------|-------------|
|
||||||
|
| `LoadStrategy.INSERT` | Insert only, fail on duplicates |
|
||||||
|
| `LoadStrategy.UPSERT` | Insert or update on conflict |
|
||||||
|
| `LoadStrategy.SKIP` | Skip rows that already exist |
|
||||||
|
|
||||||
|
## Merging registries
|
||||||
|
|
||||||
|
Split fixtures definitions across modules and merge them:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from myapp.fixtures.dev import dev_fixtures
|
||||||
|
from myapp.fixtures.prod import prod_fixtures
|
||||||
|
|
||||||
|
fixtures = fixturesRegistry()
|
||||||
|
fixtures.include_registry(registry=dev_fixtures)
|
||||||
|
fixtures.include_registry(registry=prod_fixtures)
|
||||||
|
|
||||||
|
## Pytest integration
|
||||||
|
|
||||||
|
Use [`register_fixtures`](../reference/pytest.md#fastapi_toolsets.pytest.plugin.register_fixtures) to expose each fixture in your registry as an injectable pytest fixture named `fixture_{name}` by default:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# conftest.py
|
||||||
|
import pytest
|
||||||
|
from fastapi_toolsets.pytest import create_db_session, register_fixtures
|
||||||
|
from app.fixtures import registry
|
||||||
|
from app.models import Base
|
||||||
|
|
||||||
|
DATABASE_URL = "postgresql+asyncpg://user:pass@localhost/test_db"
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session():
|
||||||
|
async with create_db_session(database_url=DATABASE_URL, base=Base, cleanup=True) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
register_fixtures(registry=registry, namespace=globals())
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# test_users.py
|
||||||
|
async def test_user_can_login(fixture_users: list[User], fixture_roles: list[Role]):
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
The load order is resolved automatically from the `depends_on` declarations in your registry. Each generated fixture receives `db_session` as a dependency and returns the list of loaded model instances.
|
||||||
|
|
||||||
|
## CLI integration
|
||||||
|
|
||||||
|
Fixtures can be triggered from the CLI. See the [CLI module](cli.md) for setup instructions.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/fixtures.md)
|
||||||
39
docs/module/logger.md
Normal file
39
docs/module/logger.md
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# Logger
|
||||||
|
|
||||||
|
Lightweight logging utilities with consistent formatting and uvicorn integration.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `logger` module provides two helpers: one to configure the root logger (and uvicorn loggers) at startup, and one to retrieve a named logger anywhere in your codebase.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
Call [`configure_logging`](../reference/logger.md#fastapi_toolsets.logger.configure_logging) once at application startup:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import configure_logging
|
||||||
|
|
||||||
|
configure_logging(level="INFO")
|
||||||
|
```
|
||||||
|
|
||||||
|
This sets up a stdout handler with a consistent format and also configures uvicorn's access and error loggers so all log output shares the same style.
|
||||||
|
|
||||||
|
## Getting a logger
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__)
|
||||||
|
logger.info("User created")
|
||||||
|
```
|
||||||
|
|
||||||
|
When called without arguments, [`get_logger`](../reference/logger.md#fastapi_toolsets.logger.get_logger) auto-detects the caller's module name via frame inspection:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Equivalent to get_logger(name=__name__)
|
||||||
|
logger = get_logger()
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/logger.md)
|
||||||
86
docs/module/metrics.md
Normal file
86
docs/module/metrics.md
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
# Metrics
|
||||||
|
|
||||||
|
Prometheus metrics integration with a decorator-based registry and multi-process support.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
=== "uv"
|
||||||
|
``` bash
|
||||||
|
uv add "fastapi-toolsets[metrics]"
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "pip"
|
||||||
|
``` bash
|
||||||
|
pip install "fastapi-toolsets[metrics]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `metrics` module provides a [`MetricsRegistry`](../reference/metrics.md#fastapi_toolsets.metrics.registry.MetricsRegistry) to declare Prometheus metrics with decorators, and an [`init_metrics`](../reference/metrics.md#fastapi_toolsets.metrics.handler.init_metrics) function to mount a `/metrics` endpoint on your FastAPI app.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.metrics import MetricsRegistry, init_metrics
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
metrics = MetricsRegistry()
|
||||||
|
|
||||||
|
init_metrics(app=app, registry=metrics)
|
||||||
|
```
|
||||||
|
|
||||||
|
This mounts the `/metrics` endpoint that Prometheus can scrape.
|
||||||
|
|
||||||
|
## Declaring metrics
|
||||||
|
|
||||||
|
### Providers
|
||||||
|
|
||||||
|
Providers are called once at startup and register metrics that are updated externally (e.g. counters, histograms):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from prometheus_client import Counter, Histogram
|
||||||
|
|
||||||
|
@metrics.register
|
||||||
|
def http_requests():
|
||||||
|
return Counter("http_requests_total", "Total HTTP requests", ["method", "status"])
|
||||||
|
|
||||||
|
@metrics.register
|
||||||
|
def request_duration():
|
||||||
|
return Histogram("request_duration_seconds", "Request duration")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Collectors
|
||||||
|
|
||||||
|
Collectors are called on every scrape. Use them for metrics that reflect current state (e.g. gauges):
|
||||||
|
|
||||||
|
```python
|
||||||
|
@metrics.register(collect=True)
|
||||||
|
def queue_depth():
|
||||||
|
gauge = Gauge("queue_depth", "Current queue depth")
|
||||||
|
gauge.set(get_current_queue_depth())
|
||||||
|
```
|
||||||
|
|
||||||
|
## Merging registries
|
||||||
|
|
||||||
|
Split metrics definitions across modules and merge them:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from myapp.metrics.http import http_metrics
|
||||||
|
from myapp.metrics.db import db_metrics
|
||||||
|
|
||||||
|
metrics = MetricsRegistry()
|
||||||
|
metrics.include_registry(registry=http_metrics)
|
||||||
|
metrics.include_registry(registry=db_metrics)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Multi-process mode
|
||||||
|
|
||||||
|
Multi-process support is enabled automatically when the `PROMETHEUS_MULTIPROC_DIR` environment variable is set. No code changes are required.
|
||||||
|
|
||||||
|
!!! warning "Environment variable name"
|
||||||
|
The correct variable is `PROMETHEUS_MULTIPROC_DIR` (not `PROMETHEUS_MULTIPROCESS_DIR`).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/metrics.md)
|
||||||
86
docs/module/pytest.md
Normal file
86
docs/module/pytest.md
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
# Pytest
|
||||||
|
|
||||||
|
Testing helpers for FastAPI applications with async client, database sessions, and parallel worker support.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
=== "uv"
|
||||||
|
``` bash
|
||||||
|
uv add "fastapi-toolsets[pytest]"
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "pip"
|
||||||
|
``` bash
|
||||||
|
pip install "fastapi-toolsets[pytest]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `pytest` module provides utilities for setting up async test clients, managing test database sessions, and supporting parallel test execution with `pytest-xdist`.
|
||||||
|
|
||||||
|
## Creating an async client
|
||||||
|
|
||||||
|
Use [`create_async_client`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_async_client) to get an `httpx.AsyncClient` configured for your FastAPI app:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import create_async_client
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def http_client(db_session):
|
||||||
|
async def _override_get_db():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
async with create_async_client(
|
||||||
|
app=app,
|
||||||
|
base_url="http://127.0.0.1/api/v1",
|
||||||
|
dependency_overrides={get_db: _override_get_db},
|
||||||
|
) as c:
|
||||||
|
yield c
|
||||||
|
```
|
||||||
|
|
||||||
|
## Database sessions in tests
|
||||||
|
|
||||||
|
Use [`create_db_session`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_db_session) to create an isolated `AsyncSession` for a test:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import create_db_session, create_worker_database
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
async def worker_db_url():
|
||||||
|
async with create_worker_database(
|
||||||
|
database_url=str(settings.SQLALCHEMY_DATABASE_URI)
|
||||||
|
) as url:
|
||||||
|
yield url
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session(worker_db_url):
|
||||||
|
async with create_db_session(
|
||||||
|
database_url=worker_db_url, base=Base, cleanup=True
|
||||||
|
) as session:
|
||||||
|
yield session
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! info
|
||||||
|
In this example, the database is reset between each test using the argument `cleanup=True`.
|
||||||
|
|
||||||
|
## Parallel testing with pytest-xdist
|
||||||
|
|
||||||
|
The examples above are already compatible with parallel test execution with `pytest-xdist`.
|
||||||
|
|
||||||
|
## Cleaning up tables
|
||||||
|
|
||||||
|
If you want to manually clean up a database you can use [`cleanup_tables`](../reference/pytest.md#fastapi_toolsets.pytest.utils.cleanup_tables), this will truncates all tables between tests for fast isolation:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import cleanup_tables
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
async def clean(db_session):
|
||||||
|
yield
|
||||||
|
await cleanup_tables(session=db_session, base=Base)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/pytest.md)
|
||||||
51
docs/module/schemas.md
Normal file
51
docs/module/schemas.md
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
# Schemas
|
||||||
|
|
||||||
|
Standardized Pydantic response models for consistent API responses across your FastAPI application.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `schemas` module provides generic response wrappers that enforce a uniform response structure. All models use `from_attributes=True` for ORM compatibility and `validate_assignment=True` for runtime type safety.
|
||||||
|
|
||||||
|
## Response models
|
||||||
|
|
||||||
|
### [`Response[T]`](../reference/schemas.md#fastapi_toolsets.schemas.Response)
|
||||||
|
|
||||||
|
The most common wrapper for a single resource response.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.schemas import Response
|
||||||
|
|
||||||
|
@router.get("/users/{id}")
|
||||||
|
async def get_user(user: User = UserDep) -> Response[UserSchema]:
|
||||||
|
return Response(data=user, message="User retrieved")
|
||||||
|
```
|
||||||
|
|
||||||
|
### [`PaginatedResponse[T]`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse)
|
||||||
|
|
||||||
|
Wraps a list of items with pagination metadata and optional facet values.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.schemas import PaginatedResponse, Pagination
|
||||||
|
|
||||||
|
@router.get("/users")
|
||||||
|
async def list_users() -> PaginatedResponse[UserSchema]:
|
||||||
|
return PaginatedResponse(
|
||||||
|
data=users,
|
||||||
|
pagination=Pagination(
|
||||||
|
total_count=100,
|
||||||
|
items_per_page=10,
|
||||||
|
page=1,
|
||||||
|
has_more=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
The optional `filter_attributes` field is populated when `facet_fields` are configured on the CRUD class (see [Filter attributes](crud.md#filter-attributes-facets)). It is `None` by default and can be hidden from API responses with `response_model_exclude_none=True`.
|
||||||
|
|
||||||
|
### [`ErrorResponse`](../reference/schemas.md#fastapi_toolsets.schemas.ErrorResponse)
|
||||||
|
|
||||||
|
Returned automatically by the exceptions handler.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/schemas.md)
|
||||||
7
docs/overrides/main.html
Normal file
7
docs/overrides/main.html
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
{% extends "base.html" %} {% block extrahead %}
|
||||||
|
<script
|
||||||
|
defer
|
||||||
|
src="https://analytics.d3vyce.fr/script.js"
|
||||||
|
data-website-id="338b8816-7b99-4c6a-82f3-15595be3fd47"
|
||||||
|
></script>
|
||||||
|
{{ super() }} {% endblock %}
|
||||||
27
docs/reference/cli.md
Normal file
27
docs/reference/cli.md
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
# `cli`
|
||||||
|
|
||||||
|
Here's the reference for the CLI configuration helpers used to load settings from `pyproject.toml`.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.cli.config`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.cli.config import (
|
||||||
|
import_from_string,
|
||||||
|
get_config_value,
|
||||||
|
get_fixtures_registry,
|
||||||
|
get_db_context,
|
||||||
|
get_custom_cli,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.import_from_string
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.get_config_value
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.get_fixtures_registry
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.get_db_context
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.get_custom_cli
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.utils.async_command
|
||||||
20
docs/reference/crud.md
Normal file
20
docs/reference/crud.md
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
# `crud`
|
||||||
|
|
||||||
|
Here's the reference for the CRUD classes, factory, and search utilities.
|
||||||
|
|
||||||
|
You can import the main symbols from `fastapi_toolsets.crud`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.crud import CrudFactory, AsyncCrud
|
||||||
|
from fastapi_toolsets.crud.search import SearchConfig, get_searchable_fields, build_search_filters
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.factory.AsyncCrud
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.factory.CrudFactory
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.search.SearchConfig
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.search.get_searchable_fields
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.search.build_search_filters
|
||||||
28
docs/reference/db.md
Normal file
28
docs/reference/db.md
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# `db`
|
||||||
|
|
||||||
|
Here's the reference for all database session utilities, transaction helpers, and locking functions.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.db`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import (
|
||||||
|
LockMode,
|
||||||
|
create_db_dependency,
|
||||||
|
create_db_context,
|
||||||
|
get_transaction,
|
||||||
|
lock_tables,
|
||||||
|
wait_for_row_change,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.LockMode
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.create_db_dependency
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.create_db_context
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.get_transaction
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.lock_tables
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.wait_for_row_change
|
||||||
13
docs/reference/dependencies.md
Normal file
13
docs/reference/dependencies.md
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# `dependencies`
|
||||||
|
|
||||||
|
Here's the reference for the FastAPI dependency factory functions.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.dependencies`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.dependencies import PathDependency, BodyDependency
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.dependencies.PathDependency
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.dependencies.BodyDependency
|
||||||
40
docs/reference/exceptions.md
Normal file
40
docs/reference/exceptions.md
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
# `exceptions`
|
||||||
|
|
||||||
|
Here's the reference for all exception classes and handler utilities.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.exceptions`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.exceptions import (
|
||||||
|
ApiException,
|
||||||
|
UnauthorizedError,
|
||||||
|
ForbiddenError,
|
||||||
|
NotFoundError,
|
||||||
|
ConflictError,
|
||||||
|
NoSearchableFieldsError,
|
||||||
|
InvalidFacetFilterError,
|
||||||
|
InvalidOrderFieldError,
|
||||||
|
generate_error_responses,
|
||||||
|
init_exceptions_handlers,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.ApiException
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.UnauthorizedError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.ForbiddenError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.NotFoundError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.ConflictError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.InvalidOrderFieldError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.generate_error_responses
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.handler.init_exceptions_handlers
|
||||||
31
docs/reference/fixtures.md
Normal file
31
docs/reference/fixtures.md
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
# `fixtures`
|
||||||
|
|
||||||
|
Here's the reference for the fixture registry, enums, and loading utilities.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.fixtures`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import (
|
||||||
|
Context,
|
||||||
|
LoadStrategy,
|
||||||
|
Fixture,
|
||||||
|
FixtureRegistry,
|
||||||
|
load_fixtures,
|
||||||
|
load_fixtures_by_context,
|
||||||
|
get_obj_by_attr,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.enum.Context
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.enum.LoadStrategy
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.registry.Fixture
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.registry.FixtureRegistry
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.utils.load_fixtures
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.utils.load_fixtures_by_context
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.utils.get_obj_by_attr
|
||||||
13
docs/reference/logger.md
Normal file
13
docs/reference/logger.md
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# `logger`
|
||||||
|
|
||||||
|
Here's the reference for the logging utilities.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.logger`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import configure_logging, get_logger
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.logger.configure_logging
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.logger.get_logger
|
||||||
15
docs/reference/metrics.md
Normal file
15
docs/reference/metrics.md
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# `metrics`
|
||||||
|
|
||||||
|
Here's the reference for the Prometheus metrics registry and endpoint handler.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.metrics`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.metrics import Metric, MetricsRegistry, init_metrics
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.metrics.registry.Metric
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.metrics.registry.MetricsRegistry
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.metrics.handler.init_metrics
|
||||||
28
docs/reference/pytest.md
Normal file
28
docs/reference/pytest.md
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# `pytest`
|
||||||
|
|
||||||
|
Here's the reference for all testing utilities and pytest fixtures.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.pytest`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import (
|
||||||
|
register_fixtures,
|
||||||
|
create_async_client,
|
||||||
|
create_db_session,
|
||||||
|
worker_database_url,
|
||||||
|
create_worker_database,
|
||||||
|
cleanup_tables,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.plugin.register_fixtures
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.create_async_client
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.create_db_session
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.worker_database_url
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.create_worker_database
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.cleanup_tables
|
||||||
34
docs/reference/schemas.md
Normal file
34
docs/reference/schemas.md
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
# `schemas` module
|
||||||
|
|
||||||
|
Here's the reference for all response models and types provided by the `schemas` module.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.schemas`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.schemas import (
|
||||||
|
PydanticBase,
|
||||||
|
ResponseStatus,
|
||||||
|
ApiError,
|
||||||
|
BaseResponse,
|
||||||
|
Response,
|
||||||
|
ErrorResponse,
|
||||||
|
Pagination,
|
||||||
|
PaginatedResponse,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.PydanticBase
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.ResponseStatus
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.ApiError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.BaseResponse
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.Response
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.ErrorResponse
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.Pagination
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.PaginatedResponse
|
||||||
0
docs_src/__init__.py
Normal file
0
docs_src/__init__.py
Normal file
0
docs_src/examples/__init__.py
Normal file
0
docs_src/examples/__init__.py
Normal file
0
docs_src/examples/pagination_search/__init__.py
Normal file
0
docs_src/examples/pagination_search/__init__.py
Normal file
9
docs_src/examples/pagination_search/app.py
Normal file
9
docs_src/examples/pagination_search/app.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||||
|
|
||||||
|
from .routes import router
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app=app)
|
||||||
|
app.include_router(router=router)
|
||||||
21
docs_src/examples/pagination_search/crud.py
Normal file
21
docs_src/examples/pagination_search/crud.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
from fastapi_toolsets.crud import CrudFactory
|
||||||
|
|
||||||
|
from .models import Article, Category
|
||||||
|
|
||||||
|
ArticleCrud = CrudFactory(
|
||||||
|
model=Article,
|
||||||
|
cursor_column=Article.created_at,
|
||||||
|
searchable_fields=[ # default fields for full-text search
|
||||||
|
Article.title,
|
||||||
|
Article.body,
|
||||||
|
(Article.category, Category.name),
|
||||||
|
],
|
||||||
|
facet_fields=[ # fields exposed as filter dropdowns
|
||||||
|
Article.status,
|
||||||
|
(Article.category, Category.name),
|
||||||
|
],
|
||||||
|
order_fields=[ # fields exposed for client-driven ordering
|
||||||
|
Article.title,
|
||||||
|
Article.created_at,
|
||||||
|
],
|
||||||
|
)
|
||||||
17
docs_src/examples/pagination_search/db.py
Normal file
17
docs_src/examples/pagination_search/db.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from fastapi_toolsets.db import create_db_context, create_db_dependency
|
||||||
|
|
||||||
|
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/postgres"
|
||||||
|
|
||||||
|
engine = create_async_engine(url=DATABASE_URL, future=True)
|
||||||
|
async_session_maker = async_sessionmaker(bind=engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
get_db = create_db_dependency(session_maker=async_session_maker)
|
||||||
|
get_db_context = create_db_context(session_maker=async_session_maker)
|
||||||
|
|
||||||
|
|
||||||
|
SessionDep = Annotated[AsyncSession, Depends(get_db)]
|
||||||
36
docs_src/examples/pagination_search/models.py
Normal file
36
docs_src/examples/pagination_search/models.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
import datetime
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from sqlalchemy import Boolean, DateTime, ForeignKey, String, Text, func
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Category(Base):
|
||||||
|
__tablename__ = "categories"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4)
|
||||||
|
name: Mapped[str] = mapped_column(String(64), unique=True)
|
||||||
|
|
||||||
|
articles: Mapped[list["Article"]] = relationship(back_populates="category")
|
||||||
|
|
||||||
|
|
||||||
|
class Article(Base):
|
||||||
|
__tablename__ = "articles"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4)
|
||||||
|
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), server_default=func.now()
|
||||||
|
)
|
||||||
|
title: Mapped[str] = mapped_column(String(256))
|
||||||
|
body: Mapped[str] = mapped_column(Text)
|
||||||
|
status: Mapped[str] = mapped_column(String(32))
|
||||||
|
published: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||||
|
category_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
|
ForeignKey("categories.id"), nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
|
category: Mapped["Category | None"] = relationship(back_populates="articles")
|
||||||
59
docs_src/examples/pagination_search/routes.py
Normal file
59
docs_src/examples/pagination_search/routes.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
|
||||||
|
from fastapi_toolsets.crud import OrderByClause
|
||||||
|
from fastapi_toolsets.schemas import PaginatedResponse
|
||||||
|
|
||||||
|
from .crud import ArticleCrud
|
||||||
|
from .db import SessionDep
|
||||||
|
from .models import Article
|
||||||
|
from .schemas import ArticleRead
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/articles")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/offset")
|
||||||
|
async def list_articles_offset(
|
||||||
|
session: SessionDep,
|
||||||
|
filter_by: Annotated[dict[str, list[str]], Depends(ArticleCrud.filter_params())],
|
||||||
|
order_by: Annotated[
|
||||||
|
OrderByClause | None,
|
||||||
|
Depends(ArticleCrud.order_params(default_field=Article.created_at)),
|
||||||
|
],
|
||||||
|
page: int = Query(1, ge=1),
|
||||||
|
items_per_page: int = Query(20, ge=1, le=100),
|
||||||
|
search: str | None = None,
|
||||||
|
) -> PaginatedResponse[ArticleRead]:
|
||||||
|
return await ArticleCrud.offset_paginate(
|
||||||
|
session=session,
|
||||||
|
page=page,
|
||||||
|
items_per_page=items_per_page,
|
||||||
|
search=search,
|
||||||
|
filter_by=filter_by or None,
|
||||||
|
order_by=order_by,
|
||||||
|
schema=ArticleRead,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/cursor")
|
||||||
|
async def list_articles_cursor(
|
||||||
|
session: SessionDep,
|
||||||
|
filter_by: Annotated[dict[str, list[str]], Depends(ArticleCrud.filter_params())],
|
||||||
|
order_by: Annotated[
|
||||||
|
OrderByClause | None,
|
||||||
|
Depends(ArticleCrud.order_params(default_field=Article.created_at)),
|
||||||
|
],
|
||||||
|
cursor: str | None = None,
|
||||||
|
items_per_page: int = Query(20, ge=1, le=100),
|
||||||
|
search: str | None = None,
|
||||||
|
) -> PaginatedResponse[ArticleRead]:
|
||||||
|
return await ArticleCrud.cursor_paginate(
|
||||||
|
session=session,
|
||||||
|
cursor=cursor,
|
||||||
|
items_per_page=items_per_page,
|
||||||
|
search=search,
|
||||||
|
filter_by=filter_by or None,
|
||||||
|
order_by=order_by,
|
||||||
|
schema=ArticleRead,
|
||||||
|
)
|
||||||
13
docs_src/examples/pagination_search/schemas.py
Normal file
13
docs_src/examples/pagination_search/schemas.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import datetime
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from fastapi_toolsets.schemas import PydanticBase
|
||||||
|
|
||||||
|
|
||||||
|
class ArticleRead(PydanticBase):
|
||||||
|
id: uuid.UUID
|
||||||
|
created_at: datetime.datetime
|
||||||
|
title: str
|
||||||
|
status: str
|
||||||
|
published: bool
|
||||||
|
category_id: uuid.UUID | None
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "fastapi-toolsets"
|
name = "fastapi-toolsets"
|
||||||
version = "0.4.0"
|
version = "1.3.0"
|
||||||
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
|
description = "Production-ready utilities for FastAPI applications"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
license-files = ["LICENSE"]
|
license-files = ["LICENSE"]
|
||||||
@@ -11,7 +11,7 @@ authors = [
|
|||||||
]
|
]
|
||||||
keywords = ["fastapi", "sqlalchemy", "postgresql"]
|
keywords = ["fastapi", "sqlalchemy", "postgresql"]
|
||||||
classifiers = [
|
classifiers = [
|
||||||
"Development Status :: 4 - Beta",
|
"Development Status :: 5 - Production/Stable",
|
||||||
"Framework :: AsyncIO",
|
"Framework :: AsyncIO",
|
||||||
"Framework :: FastAPI",
|
"Framework :: FastAPI",
|
||||||
"Framework :: Pydantic",
|
"Framework :: Pydantic",
|
||||||
@@ -31,12 +31,10 @@ classifiers = [
|
|||||||
"Typing :: Typed",
|
"Typing :: Typed",
|
||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"fastapi>=0.100.0",
|
|
||||||
"sqlalchemy[asyncio]>=2.0",
|
|
||||||
"asyncpg>=0.29.0",
|
"asyncpg>=0.29.0",
|
||||||
|
"fastapi>=0.100.0",
|
||||||
"pydantic>=2.0",
|
"pydantic>=2.0",
|
||||||
"typer>=0.9.0",
|
"sqlalchemy[asyncio]>=2.0",
|
||||||
"httpx>=0.25.0",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
@@ -46,23 +44,47 @@ Repository = "https://github.com/d3vyce/fastapi-toolsets"
|
|||||||
Issues = "https://github.com/d3vyce/fastapi-toolsets/issues"
|
Issues = "https://github.com/d3vyce/fastapi-toolsets/issues"
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
test = [
|
cli = [
|
||||||
"pytest>=8.0.0",
|
"typer>=0.9.0",
|
||||||
"pytest-anyio>=0.0.0",
|
|
||||||
"coverage>=7.0.0",
|
|
||||||
"pytest-cov>=4.0.0",
|
|
||||||
]
|
]
|
||||||
dev = [
|
metrics = [
|
||||||
"fastapi-toolsets[test]",
|
"prometheus_client>=0.20.0",
|
||||||
"ruff>=0.1.0",
|
]
|
||||||
"ty>=0.0.1a0",
|
pytest = [
|
||||||
|
"httpx>=0.25.0",
|
||||||
|
"pytest-xdist>=3.0.0",
|
||||||
|
"pytest>=8.0.0",
|
||||||
|
]
|
||||||
|
all = [
|
||||||
|
"fastapi-toolsets[cli,metrics,pytest]",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
fastapi-toolsets = "fastapi_toolsets.cli:app"
|
manager = "fastapi_toolsets.cli.app:cli"
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = [
|
||||||
|
{include-group = "tests"},
|
||||||
|
{include-group = "docs"},
|
||||||
|
"fastapi-toolsets[all]",
|
||||||
|
"ruff>=0.1.0",
|
||||||
|
"ty>=0.0.1a0",
|
||||||
|
]
|
||||||
|
tests = [
|
||||||
|
"coverage>=7.0.0",
|
||||||
|
"httpx>=0.25.0",
|
||||||
|
"pytest-anyio>=0.0.0",
|
||||||
|
"pytest-cov>=4.0.0",
|
||||||
|
"pytest-xdist>=3.0.0",
|
||||||
|
"pytest>=8.0.0",
|
||||||
|
]
|
||||||
|
docs = [
|
||||||
|
"mkdocstrings-python>=2.0.2",
|
||||||
|
"zensical>=0.0.23",
|
||||||
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["uv_build>=0.9.26,<0.10.0"]
|
requires = ["uv_build>=0.10,<0.11.0"]
|
||||||
build-backend = "uv_build"
|
build-backend = "uv_build"
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
|
|||||||
@@ -21,4 +21,4 @@ Example usage:
|
|||||||
return Response(data={"user": user.username}, message="Success")
|
return Response(data={"user": user.username}, message="Success")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.4.0"
|
__version__ = "1.3.0"
|
||||||
|
|||||||
9
src/fastapi_toolsets/_imports.py
Normal file
9
src/fastapi_toolsets/_imports.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
"""Optional dependency helpers."""
|
||||||
|
|
||||||
|
|
||||||
|
def require_extra(package: str, extra: str) -> None:
|
||||||
|
"""Raise *ImportError* with an actionable install instruction."""
|
||||||
|
raise ImportError(
|
||||||
|
f"'{package}' is required to use this module. "
|
||||||
|
f"Install it with: pip install fastapi-toolsets[{extra}]"
|
||||||
|
)
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
"""CLI for FastAPI projects."""
|
"""CLI for FastAPI projects."""
|
||||||
|
|
||||||
from .app import app, register_command
|
from .utils import async_command
|
||||||
|
|
||||||
__all__ = ["app", "register_command"]
|
__all__ = ["async_command"]
|
||||||
|
|||||||
@@ -1,97 +1,37 @@
|
|||||||
"""Main CLI application."""
|
"""Main CLI application."""
|
||||||
|
|
||||||
import importlib.util
|
try:
|
||||||
import sys
|
import typer
|
||||||
from pathlib import Path
|
except ImportError:
|
||||||
from typing import Annotated
|
from .._imports import require_extra
|
||||||
|
|
||||||
import typer
|
require_extra(package="typer", extra="cli")
|
||||||
|
|
||||||
from .commands import fixtures
|
from ..logger import configure_logging
|
||||||
|
from .config import get_custom_cli
|
||||||
|
from .pyproject import load_pyproject
|
||||||
|
|
||||||
app = typer.Typer(
|
# Use custom CLI if configured, otherwise create default one
|
||||||
name="fastapi-utils",
|
_custom_cli = get_custom_cli()
|
||||||
help="CLI utilities for FastAPI projects.",
|
|
||||||
no_args_is_help=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Register built-in commands
|
if _custom_cli is not None:
|
||||||
app.add_typer(fixtures.app, name="fixtures")
|
cli = _custom_cli
|
||||||
|
else:
|
||||||
|
cli = typer.Typer(
|
||||||
|
name="manager",
|
||||||
|
help="CLI utilities for FastAPI projects.",
|
||||||
|
no_args_is_help=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
_config = load_pyproject()
|
||||||
|
if _config.get("fixtures") and _config.get("db_context"):
|
||||||
|
from .commands.fixtures import fixture_cli
|
||||||
|
|
||||||
|
cli.add_typer(fixture_cli, name="fixtures")
|
||||||
|
|
||||||
|
|
||||||
def register_command(command: typer.Typer, name: str) -> None:
|
@cli.callback()
|
||||||
"""Register a custom command group.
|
def main(ctx: typer.Context) -> None:
|
||||||
|
|
||||||
Args:
|
|
||||||
command: Typer app for the command group
|
|
||||||
name: Name for the command group
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# In your project's cli.py:
|
|
||||||
import typer
|
|
||||||
from fastapi_toolsets.cli import app, register_command
|
|
||||||
|
|
||||||
my_commands = typer.Typer()
|
|
||||||
|
|
||||||
@my_commands.command()
|
|
||||||
def seed():
|
|
||||||
'''Seed the database.'''
|
|
||||||
...
|
|
||||||
|
|
||||||
register_command(my_commands, "db")
|
|
||||||
# Now available as: fastapi-utils db seed
|
|
||||||
"""
|
|
||||||
app.add_typer(command, name=name)
|
|
||||||
|
|
||||||
|
|
||||||
@app.callback()
|
|
||||||
def main(
|
|
||||||
ctx: typer.Context,
|
|
||||||
config: Annotated[
|
|
||||||
Path | None,
|
|
||||||
typer.Option(
|
|
||||||
"--config",
|
|
||||||
"-c",
|
|
||||||
help="Path to project config file (Python module with fixtures registry).",
|
|
||||||
envvar="FASTAPI_TOOLSETS_CONFIG",
|
|
||||||
),
|
|
||||||
] = None,
|
|
||||||
) -> None:
|
|
||||||
"""FastAPI utilities CLI."""
|
"""FastAPI utilities CLI."""
|
||||||
|
configure_logging()
|
||||||
ctx.ensure_object(dict)
|
ctx.ensure_object(dict)
|
||||||
|
|
||||||
if config:
|
|
||||||
ctx.obj["config_path"] = config
|
|
||||||
# Load the config module
|
|
||||||
config_module = _load_module_from_path(config)
|
|
||||||
ctx.obj["config_module"] = config_module
|
|
||||||
|
|
||||||
|
|
||||||
def _load_module_from_path(path: Path) -> object:
|
|
||||||
"""Load a Python module from a file path.
|
|
||||||
|
|
||||||
Handles both absolute and relative imports by adding the config's
|
|
||||||
parent directory to sys.path temporarily.
|
|
||||||
"""
|
|
||||||
path = path.resolve()
|
|
||||||
|
|
||||||
# Add the parent directory to sys.path to support relative imports
|
|
||||||
parent_dir = str(
|
|
||||||
path.parent.parent
|
|
||||||
) # Go up two levels (e.g., from app/cli_config.py to project root)
|
|
||||||
if parent_dir not in sys.path:
|
|
||||||
sys.path.insert(0, parent_dir)
|
|
||||||
|
|
||||||
# Also add immediate parent for direct module imports
|
|
||||||
immediate_parent = str(path.parent)
|
|
||||||
if immediate_parent not in sys.path:
|
|
||||||
sys.path.insert(0, immediate_parent)
|
|
||||||
|
|
||||||
spec = importlib.util.spec_from_file_location("config", path)
|
|
||||||
if spec is None or spec.loader is None:
|
|
||||||
raise typer.BadParameter(f"Cannot load module from {path}")
|
|
||||||
|
|
||||||
module = importlib.util.module_from_spec(spec)
|
|
||||||
sys.modules["config"] = module
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
return module
|
|
||||||
|
|||||||
@@ -1,138 +1,66 @@
|
|||||||
"""Fixture management commands."""
|
"""Fixture management commands."""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
from ...fixtures import Context, FixtureRegistry, LoadStrategy, load_fixtures_by_context
|
from ...fixtures import Context, LoadStrategy, load_fixtures_by_context
|
||||||
|
from ..config import get_db_context, get_fixtures_registry
|
||||||
|
from ..utils import async_command
|
||||||
|
|
||||||
app = typer.Typer(
|
fixture_cli = typer.Typer(
|
||||||
name="fixtures",
|
name="fixtures",
|
||||||
help="Manage database fixtures.",
|
help="Manage database fixtures.",
|
||||||
no_args_is_help=True,
|
no_args_is_help=True,
|
||||||
)
|
)
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
def _get_registry(ctx: typer.Context) -> FixtureRegistry:
|
@fixture_cli.command("list")
|
||||||
"""Get fixture registry from context."""
|
|
||||||
config = ctx.obj.get("config_module") if ctx.obj else None
|
|
||||||
if config is None:
|
|
||||||
raise typer.BadParameter(
|
|
||||||
"No config provided. Use --config to specify a config file with a 'fixtures' registry."
|
|
||||||
)
|
|
||||||
|
|
||||||
registry = getattr(config, "fixtures", None)
|
|
||||||
if registry is None:
|
|
||||||
raise typer.BadParameter(
|
|
||||||
"Config module must have a 'fixtures' attribute (FixtureRegistry instance)."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not isinstance(registry, FixtureRegistry):
|
|
||||||
raise typer.BadParameter(
|
|
||||||
f"'fixtures' must be a FixtureRegistry instance, got {type(registry).__name__}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return registry
|
|
||||||
|
|
||||||
|
|
||||||
def _get_db_context(ctx: typer.Context):
|
|
||||||
"""Get database context manager from config."""
|
|
||||||
config = ctx.obj.get("config_module") if ctx.obj else None
|
|
||||||
if config is None:
|
|
||||||
raise typer.BadParameter("No config provided.")
|
|
||||||
|
|
||||||
get_db_context = getattr(config, "get_db_context", None)
|
|
||||||
if get_db_context is None:
|
|
||||||
raise typer.BadParameter("Config module must have a 'get_db_context' function.")
|
|
||||||
|
|
||||||
return get_db_context
|
|
||||||
|
|
||||||
|
|
||||||
@app.command("list")
|
|
||||||
def list_fixtures(
|
def list_fixtures(
|
||||||
ctx: typer.Context,
|
ctx: typer.Context,
|
||||||
context: Annotated[
|
context: Annotated[
|
||||||
str | None,
|
Context | None,
|
||||||
typer.Option(
|
typer.Option(
|
||||||
"--context",
|
"--context",
|
||||||
"-c",
|
"-c",
|
||||||
help="Filter by context (base, production, development, testing).",
|
help="Filter by context.",
|
||||||
),
|
),
|
||||||
] = None,
|
] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""List all registered fixtures."""
|
"""List all registered fixtures."""
|
||||||
registry = _get_registry(ctx)
|
registry = get_fixtures_registry()
|
||||||
|
fixtures = registry.get_by_context(context.value) if context else registry.get_all()
|
||||||
if context:
|
|
||||||
fixtures = registry.get_by_context(context)
|
|
||||||
else:
|
|
||||||
fixtures = registry.get_all()
|
|
||||||
|
|
||||||
if not fixtures:
|
if not fixtures:
|
||||||
typer.echo("No fixtures found.")
|
print("No fixtures found.")
|
||||||
return
|
return
|
||||||
|
|
||||||
typer.echo(f"\n{'Name':<30} {'Contexts':<30} {'Dependencies'}")
|
table = Table("Name", "Contexts", "Dependencies")
|
||||||
typer.echo("-" * 80)
|
|
||||||
|
|
||||||
for fixture in fixtures:
|
for fixture in fixtures:
|
||||||
contexts = ", ".join(fixture.contexts)
|
contexts = ", ".join(fixture.contexts)
|
||||||
deps = ", ".join(fixture.depends_on) if fixture.depends_on else "-"
|
deps = ", ".join(fixture.depends_on) if fixture.depends_on else "-"
|
||||||
typer.echo(f"{fixture.name:<30} {contexts:<30} {deps}")
|
table.add_row(fixture.name, contexts, deps)
|
||||||
|
|
||||||
typer.echo(f"\nTotal: {len(fixtures)} fixture(s)")
|
console.print(table)
|
||||||
|
print(f"\nTotal: {len(fixtures)} fixture(s)")
|
||||||
|
|
||||||
|
|
||||||
@app.command("graph")
|
@fixture_cli.command("load")
|
||||||
def show_graph(
|
@async_command
|
||||||
ctx: typer.Context,
|
async def load(
|
||||||
fixture_name: Annotated[
|
|
||||||
str | None,
|
|
||||||
typer.Argument(help="Show dependencies for a specific fixture."),
|
|
||||||
] = None,
|
|
||||||
) -> None:
|
|
||||||
"""Show fixture dependency graph."""
|
|
||||||
registry = _get_registry(ctx)
|
|
||||||
|
|
||||||
if fixture_name:
|
|
||||||
try:
|
|
||||||
order = registry.resolve_dependencies(fixture_name)
|
|
||||||
typer.echo(f"\nDependency chain for '{fixture_name}':\n")
|
|
||||||
for i, name in enumerate(order):
|
|
||||||
indent = " " * i
|
|
||||||
arrow = "└─> " if i > 0 else ""
|
|
||||||
typer.echo(f"{indent}{arrow}{name}")
|
|
||||||
except KeyError:
|
|
||||||
typer.echo(f"Fixture '{fixture_name}' not found.", err=True)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
else:
|
|
||||||
# Show full graph
|
|
||||||
fixtures = registry.get_all()
|
|
||||||
|
|
||||||
typer.echo("\nFixture Dependency Graph:\n")
|
|
||||||
for fixture in fixtures:
|
|
||||||
deps = (
|
|
||||||
f" -> [{', '.join(fixture.depends_on)}]" if fixture.depends_on else ""
|
|
||||||
)
|
|
||||||
typer.echo(f" {fixture.name}{deps}")
|
|
||||||
|
|
||||||
|
|
||||||
@app.command("load")
|
|
||||||
def load(
|
|
||||||
ctx: typer.Context,
|
ctx: typer.Context,
|
||||||
contexts: Annotated[
|
contexts: Annotated[
|
||||||
list[str] | None,
|
list[Context] | None,
|
||||||
typer.Argument(
|
typer.Argument(help="Contexts to load."),
|
||||||
help="Contexts to load (base, production, development, testing)."
|
|
||||||
),
|
|
||||||
] = None,
|
] = None,
|
||||||
strategy: Annotated[
|
strategy: Annotated[
|
||||||
str,
|
LoadStrategy,
|
||||||
typer.Option(
|
typer.Option("--strategy", "-s", help="Load strategy."),
|
||||||
"--strategy", "-s", help="Load strategy: merge, insert, skip_existing."
|
] = LoadStrategy.MERGE,
|
||||||
),
|
|
||||||
] = "merge",
|
|
||||||
dry_run: Annotated[
|
dry_run: Annotated[
|
||||||
bool,
|
bool,
|
||||||
typer.Option(
|
typer.Option(
|
||||||
@@ -141,85 +69,32 @@ def load(
|
|||||||
] = False,
|
] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Load fixtures into the database."""
|
"""Load fixtures into the database."""
|
||||||
registry = _get_registry(ctx)
|
registry = get_fixtures_registry()
|
||||||
get_db_context = _get_db_context(ctx)
|
db_context = get_db_context()
|
||||||
|
|
||||||
# Parse contexts
|
context_list = [c.value for c in contexts] if contexts else [Context.BASE]
|
||||||
if contexts:
|
|
||||||
context_list = contexts
|
|
||||||
else:
|
|
||||||
context_list = [Context.BASE]
|
|
||||||
|
|
||||||
# Parse strategy
|
|
||||||
try:
|
|
||||||
load_strategy = LoadStrategy(strategy)
|
|
||||||
except ValueError:
|
|
||||||
typer.echo(
|
|
||||||
f"Invalid strategy: {strategy}. Use: merge, insert, skip_existing", err=True
|
|
||||||
)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
# Resolve what will be loaded
|
|
||||||
ordered = registry.resolve_context_dependencies(*context_list)
|
ordered = registry.resolve_context_dependencies(*context_list)
|
||||||
|
|
||||||
if not ordered:
|
if not ordered:
|
||||||
typer.echo("No fixtures to load for the specified context(s).")
|
print("No fixtures to load for the specified context(s).")
|
||||||
return
|
return
|
||||||
|
|
||||||
typer.echo(f"\nFixtures to load ({load_strategy.value} strategy):")
|
print(f"\nFixtures to load ({strategy.value} strategy):")
|
||||||
for name in ordered:
|
for name in ordered:
|
||||||
fixture = registry.get(name)
|
fixture = registry.get(name)
|
||||||
instances = list(fixture.func())
|
instances = list(fixture.func())
|
||||||
model_name = type(instances[0]).__name__ if instances else "?"
|
model_name = type(instances[0]).__name__ if instances else "?"
|
||||||
typer.echo(f" - {name}: {len(instances)} {model_name}(s)")
|
print(f" - {name}: {len(instances)} {model_name}(s)")
|
||||||
|
|
||||||
if dry_run:
|
if dry_run:
|
||||||
typer.echo("\n[Dry run - no changes made]")
|
print("\n[Dry run - no changes made]")
|
||||||
return
|
return
|
||||||
|
|
||||||
typer.echo("\nLoading...")
|
async with db_context() as session:
|
||||||
|
result = await load_fixtures_by_context(
|
||||||
async def do_load():
|
session, registry, *context_list, strategy=strategy
|
||||||
async with get_db_context() as session:
|
)
|
||||||
result = await load_fixtures_by_context(
|
|
||||||
session, registry, *context_list, strategy=load_strategy
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
result = asyncio.run(do_load())
|
|
||||||
|
|
||||||
total = sum(len(items) for items in result.values())
|
total = sum(len(items) for items in result.values())
|
||||||
typer.echo(f"\nLoaded {total} record(s) successfully.")
|
print(f"\nLoaded {total} record(s) successfully.")
|
||||||
|
|
||||||
|
|
||||||
@app.command("show")
|
|
||||||
def show_fixture(
|
|
||||||
ctx: typer.Context,
|
|
||||||
name: Annotated[str, typer.Argument(help="Fixture name to show.")],
|
|
||||||
) -> None:
|
|
||||||
"""Show details of a specific fixture."""
|
|
||||||
registry = _get_registry(ctx)
|
|
||||||
|
|
||||||
try:
|
|
||||||
fixture = registry.get(name)
|
|
||||||
except KeyError:
|
|
||||||
typer.echo(f"Fixture '{name}' not found.", err=True)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
typer.echo(f"\nFixture: {fixture.name}")
|
|
||||||
typer.echo(f"Contexts: {', '.join(fixture.contexts)}")
|
|
||||||
typer.echo(
|
|
||||||
f"Dependencies: {', '.join(fixture.depends_on) if fixture.depends_on else 'None'}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Show instances
|
|
||||||
instances = list(fixture.func())
|
|
||||||
if instances:
|
|
||||||
model_name = type(instances[0]).__name__
|
|
||||||
typer.echo(f"\nInstances ({len(instances)} {model_name}):")
|
|
||||||
for instance in instances[:10]: # Limit to 10
|
|
||||||
typer.echo(f" - {instance!r}")
|
|
||||||
if len(instances) > 10:
|
|
||||||
typer.echo(f" ... and {len(instances) - 10} more")
|
|
||||||
else:
|
|
||||||
typer.echo("\nNo instances (empty fixture)")
|
|
||||||
|
|||||||
125
src/fastapi_toolsets/cli/config.py
Normal file
125
src/fastapi_toolsets/cli/config.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""CLI configuration and dynamic imports."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
from typing import TYPE_CHECKING, Any, Literal, overload
|
||||||
|
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from .pyproject import find_pyproject, load_pyproject
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..fixtures import FixtureRegistry
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_project_in_path():
|
||||||
|
"""Add project root to sys.path if not installed in editable mode."""
|
||||||
|
pyproject = find_pyproject()
|
||||||
|
if pyproject:
|
||||||
|
project_root = str(pyproject.parent)
|
||||||
|
if project_root not in sys.path:
|
||||||
|
sys.path.insert(0, project_root)
|
||||||
|
|
||||||
|
|
||||||
|
def import_from_string(import_path: str) -> Any:
|
||||||
|
"""Import an object from a dotted string path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
import_path: Import path in ``"module.submodule:attribute"`` format
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The imported attribute
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
typer.BadParameter: If the import path is invalid or import fails
|
||||||
|
"""
|
||||||
|
if ":" not in import_path:
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"Invalid import path '{import_path}'. Expected format: 'module:attribute'"
|
||||||
|
)
|
||||||
|
|
||||||
|
module_path, attr_name = import_path.rsplit(":", 1)
|
||||||
|
|
||||||
|
_ensure_project_in_path()
|
||||||
|
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(module_path)
|
||||||
|
except ImportError as e:
|
||||||
|
raise typer.BadParameter(f"Cannot import module '{module_path}': {e}")
|
||||||
|
|
||||||
|
if not hasattr(module, attr_name):
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"Module '{module_path}' has no attribute '{attr_name}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
return getattr(module, attr_name)
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_config_value(key: str, required: Literal[True]) -> Any: ... # pragma: no cover
|
||||||
|
@overload
|
||||||
|
def get_config_value(
|
||||||
|
key: str, required: bool = False
|
||||||
|
) -> Any | None: ... # pragma: no cover
|
||||||
|
def get_config_value(key: str, required: bool = False) -> Any | None:
|
||||||
|
"""Get a configuration value from pyproject.toml.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The configuration key in [tool.fastapi-toolsets].
|
||||||
|
required: If True, raises an error when the key is missing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The configuration value, or None if not found and not required.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
typer.BadParameter: If required=True and the key is missing.
|
||||||
|
"""
|
||||||
|
config = load_pyproject()
|
||||||
|
value = config.get(key)
|
||||||
|
|
||||||
|
if required and value is None:
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"No '{key}' configured. "
|
||||||
|
f"Add '{key}' to [tool.fastapi-toolsets] in pyproject.toml."
|
||||||
|
)
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def get_fixtures_registry() -> FixtureRegistry:
|
||||||
|
"""Import and return the fixtures registry from config."""
|
||||||
|
from ..fixtures import FixtureRegistry
|
||||||
|
|
||||||
|
import_path = get_config_value("fixtures", required=True)
|
||||||
|
registry = import_from_string(import_path)
|
||||||
|
|
||||||
|
if not isinstance(registry, FixtureRegistry):
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"'fixtures' must be a FixtureRegistry instance, got {type(registry).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return registry
|
||||||
|
|
||||||
|
|
||||||
|
def get_db_context() -> Any:
|
||||||
|
"""Import and return the db_context function from config."""
|
||||||
|
import_path = get_config_value("db_context", required=True)
|
||||||
|
return import_from_string(import_path)
|
||||||
|
|
||||||
|
|
||||||
|
def get_custom_cli() -> typer.Typer | None:
|
||||||
|
"""Import and return the custom CLI Typer instance from config."""
|
||||||
|
import_path = get_config_value("custom_cli")
|
||||||
|
if not import_path:
|
||||||
|
return None
|
||||||
|
|
||||||
|
custom = import_from_string(import_path)
|
||||||
|
|
||||||
|
if not isinstance(custom, typer.Typer):
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"'custom_cli' must be a Typer instance, got {type(custom).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return custom
|
||||||
49
src/fastapi_toolsets/cli/pyproject.py
Normal file
49
src/fastapi_toolsets/cli/pyproject.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
"""Pyproject.toml discovery and loading."""
|
||||||
|
|
||||||
|
import tomllib
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
TOOL_NAME = "fastapi-toolsets"
|
||||||
|
|
||||||
|
|
||||||
|
def find_pyproject(start_path: Path | None = None) -> Path | None:
|
||||||
|
"""Find pyproject.toml by walking up the directory tree.
|
||||||
|
|
||||||
|
Similar to how pytest, black, and ruff discover their config files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_path: Directory to start searching from. Defaults to cwd.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to pyproject.toml, or None if not found.
|
||||||
|
"""
|
||||||
|
path = (start_path or Path.cwd()).resolve()
|
||||||
|
|
||||||
|
for directory in [path, *path.parents]:
|
||||||
|
pyproject = directory / "pyproject.toml"
|
||||||
|
if pyproject.is_file():
|
||||||
|
return pyproject
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def load_pyproject(path: Path | None = None) -> dict:
|
||||||
|
"""Load tool configuration from pyproject.toml.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Explicit path to pyproject.toml. If None, searches up from cwd.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The [tool.fastapi-toolsets] section as a dict, or empty dict if not found.
|
||||||
|
"""
|
||||||
|
pyproject_path = path or find_pyproject()
|
||||||
|
|
||||||
|
if not pyproject_path:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(pyproject_path, "rb") as f:
|
||||||
|
data = tomllib.load(f)
|
||||||
|
return data.get("tool", {}).get(TOOL_NAME, {})
|
||||||
|
except (OSError, tomllib.TOMLDecodeError):
|
||||||
|
return {}
|
||||||
29
src/fastapi_toolsets/cli/utils.py
Normal file
29
src/fastapi_toolsets/cli/utils.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
"""CLI utility functions."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import functools
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
|
from typing import Any, ParamSpec, TypeVar
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def async_command(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, T]:
|
||||||
|
"""Decorator to run an async function as a sync CLI command.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
@fixture_cli.command("load")
|
||||||
|
@async_command
|
||||||
|
async def load(ctx: typer.Context) -> None:
|
||||||
|
async with get_db_context() as session:
|
||||||
|
await load_fixtures(session, registry)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||||
|
return asyncio.run(func(*args, **kwargs))
|
||||||
|
|
||||||
|
return wrapper
|
||||||
@@ -1,17 +1,21 @@
|
|||||||
"""Generic async CRUD operations for SQLAlchemy models."""
|
"""Generic async CRUD operations for SQLAlchemy models."""
|
||||||
|
|
||||||
from ..exceptions import NoSearchableFieldsError
|
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
|
||||||
from .factory import CrudFactory
|
from .factory import CrudFactory, JoinType, M2MFieldType, OrderByClause
|
||||||
from .search import (
|
from .search import (
|
||||||
|
FacetFieldType,
|
||||||
SearchConfig,
|
SearchConfig,
|
||||||
SearchFieldType,
|
|
||||||
get_searchable_fields,
|
get_searchable_fields,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CrudFactory",
|
"CrudFactory",
|
||||||
|
"FacetFieldType",
|
||||||
"get_searchable_fields",
|
"get_searchable_fields",
|
||||||
|
"InvalidFacetFilterError",
|
||||||
|
"JoinType",
|
||||||
|
"M2MFieldType",
|
||||||
"NoSearchableFieldsError",
|
"NoSearchableFieldsError",
|
||||||
|
"OrderByClause",
|
||||||
"SearchConfig",
|
"SearchConfig",
|
||||||
"SearchFieldType",
|
|
||||||
]
|
]
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,19 +1,23 @@
|
|||||||
"""Search utilities for AsyncCrud."""
|
"""Search utilities for AsyncCrud."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections import Counter
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Literal
|
from typing import TYPE_CHECKING, Any, Literal
|
||||||
|
|
||||||
from sqlalchemy import String, or_
|
from sqlalchemy import String, or_, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||||
|
|
||||||
from ..exceptions import NoSearchableFieldsError
|
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sqlalchemy.sql.elements import ColumnElement
|
from sqlalchemy.sql.elements import ColumnElement
|
||||||
|
|
||||||
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
|
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
|
||||||
|
FacetFieldType = SearchFieldType
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -89,6 +93,9 @@ def build_search_filters(
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (filter_conditions, joins_needed)
|
Tuple of (filter_conditions, joins_needed)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NoSearchableFieldsError: If no searchable field has been configured
|
||||||
"""
|
"""
|
||||||
# Normalize input
|
# Normalize input
|
||||||
if isinstance(search, str):
|
if isinstance(search, str):
|
||||||
@@ -129,13 +136,14 @@ def build_search_filters(
|
|||||||
else:
|
else:
|
||||||
column = field
|
column = field
|
||||||
|
|
||||||
# Build the filter
|
# Build the filter (cast to String for non-text columns)
|
||||||
|
column_as_string = column.cast(String)
|
||||||
if config.case_sensitive:
|
if config.case_sensitive:
|
||||||
filters.append(column.like(f"%{query}%"))
|
filters.append(column_as_string.like(f"%{query}%"))
|
||||||
else:
|
else:
|
||||||
filters.append(column.ilike(f"%{query}%"))
|
filters.append(column_as_string.ilike(f"%{query}%"))
|
||||||
|
|
||||||
if not filters:
|
if not filters: # pragma: no cover
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
# Combine based on match_mode
|
# Combine based on match_mode
|
||||||
@@ -143,3 +151,145 @@ def build_search_filters(
|
|||||||
return [or_(*filters)], joins
|
return [or_(*filters)], joins
|
||||||
else:
|
else:
|
||||||
return filters, joins
|
return filters, joins
|
||||||
|
|
||||||
|
|
||||||
|
def facet_keys(facet_fields: Sequence[FacetFieldType]) -> list[str]:
|
||||||
|
"""Return a key for each facet field, disambiguating duplicate column keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
facet_fields: Sequence of facet fields — either direct columns or
|
||||||
|
relationship tuples ``(rel, ..., column)``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of string keys, one per facet field, in the same order.
|
||||||
|
"""
|
||||||
|
raw: list[tuple[str, str | None]] = []
|
||||||
|
for field in facet_fields:
|
||||||
|
if isinstance(field, tuple):
|
||||||
|
rel = field[-2]
|
||||||
|
column = field[-1]
|
||||||
|
raw.append((column.key, rel.key))
|
||||||
|
else:
|
||||||
|
raw.append((field.key, None))
|
||||||
|
|
||||||
|
counts = Counter(col_key for col_key, _ in raw)
|
||||||
|
keys: list[str] = []
|
||||||
|
for col_key, rel_key in raw:
|
||||||
|
if counts[col_key] > 1 and rel_key is not None:
|
||||||
|
keys.append(f"{rel_key}__{col_key}")
|
||||||
|
else:
|
||||||
|
keys.append(col_key)
|
||||||
|
return keys
|
||||||
|
|
||||||
|
|
||||||
|
async def build_facets(
|
||||||
|
session: "AsyncSession",
|
||||||
|
model: type[DeclarativeBase],
|
||||||
|
facet_fields: Sequence[FacetFieldType],
|
||||||
|
*,
|
||||||
|
base_filters: "list[ColumnElement[bool]] | None" = None,
|
||||||
|
base_joins: list[InstrumentedAttribute[Any]] | None = None,
|
||||||
|
) -> dict[str, list[Any]]:
|
||||||
|
"""Return distinct values for each facet field, respecting current filters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: DB async session
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
facet_fields: Columns or relationship tuples to facet on
|
||||||
|
base_filters: Filter conditions already applied to the main query (search + caller filters)
|
||||||
|
base_joins: Relationship joins already applied to the main query
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping column key to sorted list of distinct non-None values
|
||||||
|
"""
|
||||||
|
existing_join_keys: set[str] = {str(j) for j in (base_joins or [])}
|
||||||
|
|
||||||
|
keys = facet_keys(facet_fields)
|
||||||
|
|
||||||
|
async def _query_facet(field: FacetFieldType, key: str) -> tuple[str, list[Any]]:
|
||||||
|
if isinstance(field, tuple):
|
||||||
|
# Relationship chain: (User.role, Role.name) — last element is the column
|
||||||
|
rels = field[:-1]
|
||||||
|
column = field[-1]
|
||||||
|
else:
|
||||||
|
rels = ()
|
||||||
|
column = field
|
||||||
|
|
||||||
|
q = select(column).select_from(model).distinct()
|
||||||
|
|
||||||
|
# Apply base joins (already done on main query, but needed here independently)
|
||||||
|
for rel in base_joins or []:
|
||||||
|
q = q.outerjoin(rel)
|
||||||
|
|
||||||
|
# Add any extra joins required by this facet field that aren't already in base_joins
|
||||||
|
for rel in rels:
|
||||||
|
if str(rel) not in existing_join_keys:
|
||||||
|
q = q.outerjoin(rel)
|
||||||
|
|
||||||
|
if base_filters:
|
||||||
|
from sqlalchemy import and_
|
||||||
|
|
||||||
|
q = q.where(and_(*base_filters))
|
||||||
|
|
||||||
|
q = q.order_by(column)
|
||||||
|
result = await session.execute(q)
|
||||||
|
values = [row[0] for row in result.all() if row[0] is not None]
|
||||||
|
return key, values
|
||||||
|
|
||||||
|
pairs = await asyncio.gather(
|
||||||
|
*[_query_facet(f, k) for f, k in zip(facet_fields, keys)]
|
||||||
|
)
|
||||||
|
return dict(pairs)
|
||||||
|
|
||||||
|
|
||||||
|
def build_filter_by(
|
||||||
|
filter_by: dict[str, Any],
|
||||||
|
facet_fields: Sequence[FacetFieldType],
|
||||||
|
) -> tuple["list[ColumnElement[bool]]", list[InstrumentedAttribute[Any]]]:
|
||||||
|
"""Translate a {column_key: value} dict into SQLAlchemy filter conditions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filter_by: Mapping of column key to scalar value or list of values
|
||||||
|
facet_fields: Declared facet fields to validate keys against
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (filter_conditions, joins_needed)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidFacetFilterError: If a key in filter_by is not a declared facet field
|
||||||
|
"""
|
||||||
|
index: dict[
|
||||||
|
str, tuple[InstrumentedAttribute[Any], list[InstrumentedAttribute[Any]]]
|
||||||
|
] = {}
|
||||||
|
for key, field in zip(facet_keys(facet_fields), facet_fields):
|
||||||
|
if isinstance(field, tuple):
|
||||||
|
rels = list(field[:-1])
|
||||||
|
column = field[-1]
|
||||||
|
else:
|
||||||
|
rels = []
|
||||||
|
column = field
|
||||||
|
index[key] = (column, rels)
|
||||||
|
|
||||||
|
valid_keys = set(index)
|
||||||
|
filters: list[ColumnElement[bool]] = []
|
||||||
|
joins: list[InstrumentedAttribute[Any]] = []
|
||||||
|
added_join_keys: set[str] = set()
|
||||||
|
|
||||||
|
for key, value in filter_by.items():
|
||||||
|
if key not in index:
|
||||||
|
raise InvalidFacetFilterError(key, valid_keys)
|
||||||
|
|
||||||
|
column, rels = index[key]
|
||||||
|
|
||||||
|
for rel in rels:
|
||||||
|
rel_key = str(rel)
|
||||||
|
if rel_key not in added_join_keys:
|
||||||
|
joins.append(rel)
|
||||||
|
added_join_keys.add(rel_key)
|
||||||
|
|
||||||
|
if isinstance(value, list):
|
||||||
|
filters.append(column.in_(value))
|
||||||
|
else:
|
||||||
|
filters.append(column == value)
|
||||||
|
|
||||||
|
return filters, joins
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
"""Database utilities: sessions, transactions, and locks."""
|
"""Database utilities: sessions, transactions, and locks."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from collections.abc import AsyncGenerator, Callable
|
from collections.abc import AsyncGenerator, Callable
|
||||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
@@ -14,6 +16,7 @@ __all__ = [
|
|||||||
"create_db_dependency",
|
"create_db_dependency",
|
||||||
"lock_tables",
|
"lock_tables",
|
||||||
"get_transaction",
|
"get_transaction",
|
||||||
|
"wait_for_row_change",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -32,6 +35,7 @@ def create_db_dependency(
|
|||||||
An async generator function usable with FastAPI's Depends()
|
An async generator function usable with FastAPI's Depends()
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||||
from fastapi_toolsets.db import create_db_dependency
|
from fastapi_toolsets.db import create_db_dependency
|
||||||
@@ -43,6 +47,7 @@ def create_db_dependency(
|
|||||||
@app.get("/users")
|
@app.get("/users")
|
||||||
async def list_users(session: AsyncSession = Depends(get_db)):
|
async def list_users(session: AsyncSession = Depends(get_db)):
|
||||||
...
|
...
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||||
@@ -69,6 +74,7 @@ def create_db_context(
|
|||||||
An async context manager function
|
An async context manager function
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||||
from fastapi_toolsets.db import create_db_context
|
from fastapi_toolsets.db import create_db_context
|
||||||
|
|
||||||
@@ -80,6 +86,7 @@ def create_db_context(
|
|||||||
async with get_db_context() as session:
|
async with get_db_context() as session:
|
||||||
user = await UserCrud.get(session, [User.id == 1])
|
user = await UserCrud.get(session, [User.id == 1])
|
||||||
...
|
...
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
get_db = create_db_dependency(session_maker)
|
get_db = create_db_dependency(session_maker)
|
||||||
return asynccontextmanager(get_db)
|
return asynccontextmanager(get_db)
|
||||||
@@ -101,9 +108,11 @@ async def get_transaction(
|
|||||||
The session within the transaction context
|
The session within the transaction context
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
async with get_transaction(session):
|
async with get_transaction(session):
|
||||||
session.add(model)
|
session.add(model)
|
||||||
# Auto-commits on exit, rolls back on exception
|
# Auto-commits on exit, rolls back on exception
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
if session.in_transaction():
|
if session.in_transaction():
|
||||||
async with session.begin_nested():
|
async with session.begin_nested():
|
||||||
@@ -155,6 +164,7 @@ async def lock_tables(
|
|||||||
SQLAlchemyError: If lock cannot be acquired within timeout
|
SQLAlchemyError: If lock cannot be acquired within timeout
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
from fastapi_toolsets.db import lock_tables, LockMode
|
from fastapi_toolsets.db import lock_tables, LockMode
|
||||||
|
|
||||||
async with lock_tables(session, [User, Account]):
|
async with lock_tables(session, [User, Account]):
|
||||||
@@ -166,6 +176,7 @@ async def lock_tables(
|
|||||||
async with lock_tables(session, [Order], mode=LockMode.EXCLUSIVE):
|
async with lock_tables(session, [Order], mode=LockMode.EXCLUSIVE):
|
||||||
# Exclusive lock - no other transactions can access
|
# Exclusive lock - no other transactions can access
|
||||||
await process_order(session, order_id)
|
await process_order(session, order_id)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
table_names = ",".join(table.__tablename__ for table in tables)
|
table_names = ",".join(table.__tablename__ for table in tables)
|
||||||
|
|
||||||
@@ -173,3 +184,85 @@ async def lock_tables(
|
|||||||
await session.execute(text(f"SET LOCAL lock_timeout='{timeout}'"))
|
await session.execute(text(f"SET LOCAL lock_timeout='{timeout}'"))
|
||||||
await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE"))
|
await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE"))
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
_M = TypeVar("_M", bound=DeclarativeBase)
|
||||||
|
|
||||||
|
|
||||||
|
async def wait_for_row_change(
|
||||||
|
session: AsyncSession,
|
||||||
|
model: type[_M],
|
||||||
|
pk_value: Any,
|
||||||
|
*,
|
||||||
|
columns: list[str] | None = None,
|
||||||
|
interval: float = 0.5,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> _M:
|
||||||
|
"""Poll a database row until a change is detected.
|
||||||
|
|
||||||
|
Queries the row every ``interval`` seconds and returns the model instance
|
||||||
|
once a change is detected in any column (or only the specified ``columns``).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: AsyncSession instance
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
pk_value: Primary key value of the row to watch
|
||||||
|
columns: Optional list of column names to watch. If None, all columns
|
||||||
|
are watched.
|
||||||
|
interval: Polling interval in seconds (default: 0.5)
|
||||||
|
timeout: Maximum time to wait in seconds. None means wait forever.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The refreshed model instance with updated values
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
LookupError: If the row does not exist or is deleted during polling
|
||||||
|
TimeoutError: If timeout expires before a change is detected
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import wait_for_row_change
|
||||||
|
|
||||||
|
# Wait for any column to change
|
||||||
|
updated = await wait_for_row_change(session, User, user_id)
|
||||||
|
|
||||||
|
# Watch specific columns with a timeout
|
||||||
|
updated = await wait_for_row_change(
|
||||||
|
session, User, user_id,
|
||||||
|
columns=["status", "email"],
|
||||||
|
interval=1.0,
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
instance = await session.get(model, pk_value)
|
||||||
|
if instance is None:
|
||||||
|
raise LookupError(f"{model.__name__} with pk={pk_value!r} not found")
|
||||||
|
|
||||||
|
if columns is not None:
|
||||||
|
watch_cols = columns
|
||||||
|
else:
|
||||||
|
watch_cols = [attr.key for attr in model.__mapper__.column_attrs]
|
||||||
|
|
||||||
|
initial = {col: getattr(instance, col) for col in watch_cols}
|
||||||
|
|
||||||
|
elapsed = 0.0
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
elapsed += interval
|
||||||
|
|
||||||
|
if timeout is not None and elapsed >= timeout:
|
||||||
|
raise TimeoutError(
|
||||||
|
f"No change detected on {model.__name__} "
|
||||||
|
f"with pk={pk_value!r} within {timeout}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
session.expunge(instance)
|
||||||
|
instance = await session.get(model, pk_value)
|
||||||
|
|
||||||
|
if instance is None:
|
||||||
|
raise LookupError(f"{model.__name__} with pk={pk_value!r} was deleted")
|
||||||
|
|
||||||
|
current = {col: getattr(instance, col) for col in watch_cols}
|
||||||
|
if current != initial:
|
||||||
|
return instance
|
||||||
|
|||||||
145
src/fastapi_toolsets/dependencies.py
Normal file
145
src/fastapi_toolsets/dependencies.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
"""Dependency factories for FastAPI routes."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from collections.abc import AsyncGenerator, Callable
|
||||||
|
from typing import Any, TypeVar, cast
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from .crud import CrudFactory
|
||||||
|
|
||||||
|
__all__ = ["BodyDependency", "PathDependency"]
|
||||||
|
|
||||||
|
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||||
|
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]]
|
||||||
|
|
||||||
|
|
||||||
|
def PathDependency(
|
||||||
|
model: type[ModelType],
|
||||||
|
field: Any,
|
||||||
|
*,
|
||||||
|
session_dep: SessionDependency,
|
||||||
|
param_name: str | None = None,
|
||||||
|
) -> ModelType:
|
||||||
|
"""Create a dependency that fetches a DB object from a path parameter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
field: Model field to filter by (e.g., User.id)
|
||||||
|
session_dep: Session dependency function (e.g., get_db)
|
||||||
|
param_name: Path parameter name (defaults to model_field, e.g., user_id)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Depends() instance that resolves to the model instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If no matching record is found
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
UserDep = PathDependency(User, User.id, session_dep=get_db)
|
||||||
|
|
||||||
|
@router.get("/user/{id}")
|
||||||
|
async def get(
|
||||||
|
user: User = UserDep,
|
||||||
|
): ...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
crud = CrudFactory(model)
|
||||||
|
name = (
|
||||||
|
param_name
|
||||||
|
if param_name is not None
|
||||||
|
else "{}_{}".format(model.__name__.lower(), field.key)
|
||||||
|
)
|
||||||
|
python_type = field.type.python_type
|
||||||
|
|
||||||
|
async def dependency(
|
||||||
|
session: AsyncSession = Depends(session_dep), **kwargs: Any
|
||||||
|
) -> ModelType:
|
||||||
|
value = kwargs[name]
|
||||||
|
return await crud.get(session, filters=[field == value])
|
||||||
|
|
||||||
|
setattr(
|
||||||
|
dependency,
|
||||||
|
"__signature__",
|
||||||
|
inspect.Signature(
|
||||||
|
parameters=[
|
||||||
|
inspect.Parameter(
|
||||||
|
name, inspect.Parameter.KEYWORD_ONLY, annotation=python_type
|
||||||
|
),
|
||||||
|
inspect.Parameter(
|
||||||
|
"session",
|
||||||
|
inspect.Parameter.KEYWORD_ONLY,
|
||||||
|
annotation=AsyncSession,
|
||||||
|
default=Depends(session_dep),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return cast(ModelType, Depends(cast(Callable[..., ModelType], dependency)))
|
||||||
|
|
||||||
|
|
||||||
|
def BodyDependency(
|
||||||
|
model: type[ModelType],
|
||||||
|
field: Any,
|
||||||
|
*,
|
||||||
|
session_dep: SessionDependency,
|
||||||
|
body_field: str,
|
||||||
|
) -> ModelType:
|
||||||
|
"""Create a dependency that fetches a DB object from a body field.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
field: Model field to filter by (e.g., User.id)
|
||||||
|
session_dep: Session dependency function (e.g., get_db)
|
||||||
|
body_field: Name of the field in the request body
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Depends() instance that resolves to the model instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If no matching record is found
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
UserDep = BodyDependency(
|
||||||
|
User, User.ctfd_id, session_dep=get_db, body_field="user_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.post("/assign")
|
||||||
|
async def assign(
|
||||||
|
user: User = UserDep,
|
||||||
|
): ...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
crud = CrudFactory(model)
|
||||||
|
python_type = field.type.python_type
|
||||||
|
|
||||||
|
async def dependency(
|
||||||
|
session: AsyncSession = Depends(session_dep), **kwargs: Any
|
||||||
|
) -> ModelType:
|
||||||
|
value = kwargs[body_field]
|
||||||
|
return await crud.get(session, filters=[field == value])
|
||||||
|
|
||||||
|
setattr(
|
||||||
|
dependency,
|
||||||
|
"__signature__",
|
||||||
|
inspect.Signature(
|
||||||
|
parameters=[
|
||||||
|
inspect.Parameter(
|
||||||
|
body_field, inspect.Parameter.KEYWORD_ONLY, annotation=python_type
|
||||||
|
),
|
||||||
|
inspect.Parameter(
|
||||||
|
"session",
|
||||||
|
inspect.Parameter.KEYWORD_ONLY,
|
||||||
|
annotation=AsyncSession,
|
||||||
|
default=Depends(session_dep),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return cast(ModelType, Depends(cast(Callable[..., ModelType], dependency)))
|
||||||
@@ -1,7 +1,12 @@
|
|||||||
|
"""Standardized API exceptions and error response handlers."""
|
||||||
|
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
|
ApiError,
|
||||||
ApiException,
|
ApiException,
|
||||||
ConflictError,
|
ConflictError,
|
||||||
ForbiddenError,
|
ForbiddenError,
|
||||||
|
InvalidFacetFilterError,
|
||||||
|
InvalidOrderFieldError,
|
||||||
NoSearchableFieldsError,
|
NoSearchableFieldsError,
|
||||||
NotFoundError,
|
NotFoundError,
|
||||||
UnauthorizedError,
|
UnauthorizedError,
|
||||||
@@ -10,11 +15,14 @@ from .exceptions import (
|
|||||||
from .handler import init_exceptions_handlers
|
from .handler import init_exceptions_handlers
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"init_exceptions_handlers",
|
"ApiError",
|
||||||
"generate_error_responses",
|
|
||||||
"ApiException",
|
"ApiException",
|
||||||
"ConflictError",
|
"ConflictError",
|
||||||
"ForbiddenError",
|
"ForbiddenError",
|
||||||
|
"generate_error_responses",
|
||||||
|
"init_exceptions_handlers",
|
||||||
|
"InvalidFacetFilterError",
|
||||||
|
"InvalidOrderFieldError",
|
||||||
"NoSearchableFieldsError",
|
"NoSearchableFieldsError",
|
||||||
"NotFoundError",
|
"NotFoundError",
|
||||||
"UnauthorizedError",
|
"UnauthorizedError",
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ class ApiException(Exception):
|
|||||||
The exception handler will use api_error to generate the response.
|
The exception handler will use api_error to generate the response.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
class CustomError(ApiException):
|
class CustomError(ApiException):
|
||||||
api_error = ApiError(
|
api_error = ApiError(
|
||||||
code=400,
|
code=400,
|
||||||
@@ -19,6 +20,7 @@ class ApiException(Exception):
|
|||||||
desc="The request was invalid.",
|
desc="The request was invalid.",
|
||||||
err_code="CUSTOM-400",
|
err_code="CUSTOM-400",
|
||||||
)
|
)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
api_error: ClassVar[ApiError]
|
api_error: ClassVar[ApiError]
|
||||||
@@ -76,49 +78,6 @@ class ConflictError(ApiException):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InsufficientRolesError(ForbiddenError):
|
|
||||||
"""User does not have the required roles."""
|
|
||||||
|
|
||||||
api_error = ApiError(
|
|
||||||
code=403,
|
|
||||||
msg="Insufficient Roles",
|
|
||||||
desc="You do not have the required roles to access this resource.",
|
|
||||||
err_code="RBAC-403",
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self, required_roles: list[str], user_roles: set[str] | None = None):
|
|
||||||
self.required_roles = required_roles
|
|
||||||
self.user_roles = user_roles
|
|
||||||
|
|
||||||
desc = f"Required roles: {', '.join(required_roles)}"
|
|
||||||
if user_roles is not None:
|
|
||||||
desc += f". User has: {', '.join(user_roles) if user_roles else 'no roles'}"
|
|
||||||
|
|
||||||
super().__init__(desc)
|
|
||||||
|
|
||||||
|
|
||||||
class UserNotFoundError(NotFoundError):
|
|
||||||
"""User was not found."""
|
|
||||||
|
|
||||||
api_error = ApiError(
|
|
||||||
code=404,
|
|
||||||
msg="User Not Found",
|
|
||||||
desc="The requested user was not found.",
|
|
||||||
err_code="USER-404",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RoleNotFoundError(NotFoundError):
|
|
||||||
"""Role was not found."""
|
|
||||||
|
|
||||||
api_error = ApiError(
|
|
||||||
code=404,
|
|
||||||
msg="Role Not Found",
|
|
||||||
desc="The requested role was not found.",
|
|
||||||
err_code="ROLE-404",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class NoSearchableFieldsError(ApiException):
|
class NoSearchableFieldsError(ApiException):
|
||||||
"""Raised when search is requested but no searchable fields are available."""
|
"""Raised when search is requested but no searchable fields are available."""
|
||||||
|
|
||||||
@@ -130,6 +89,11 @@ class NoSearchableFieldsError(ApiException):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, model: type) -> None:
|
def __init__(self, model: type) -> None:
|
||||||
|
"""Initialize the exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The SQLAlchemy model class that has no searchable fields
|
||||||
|
"""
|
||||||
self.model = model
|
self.model = model
|
||||||
detail = (
|
detail = (
|
||||||
f"No searchable fields found for model '{model.__name__}'. "
|
f"No searchable fields found for model '{model.__name__}'. "
|
||||||
@@ -138,6 +102,57 @@ class NoSearchableFieldsError(ApiException):
|
|||||||
super().__init__(detail)
|
super().__init__(detail)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidFacetFilterError(ApiException):
|
||||||
|
"""Raised when filter_by contains a key not declared in facet_fields."""
|
||||||
|
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400,
|
||||||
|
msg="Invalid Facet Filter",
|
||||||
|
desc="One or more filter_by keys are not declared as facet fields.",
|
||||||
|
err_code="FACET-400",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, key: str, valid_keys: set[str]) -> None:
|
||||||
|
"""Initialize the exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The unknown filter key provided by the caller
|
||||||
|
valid_keys: Set of valid keys derived from the declared facet_fields
|
||||||
|
"""
|
||||||
|
self.key = key
|
||||||
|
self.valid_keys = valid_keys
|
||||||
|
detail = (
|
||||||
|
f"'{key}' is not a declared facet field. "
|
||||||
|
f"Valid keys: {sorted(valid_keys) or 'none — set facet_fields on the CRUD class'}."
|
||||||
|
)
|
||||||
|
super().__init__(detail)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidOrderFieldError(ApiException):
|
||||||
|
"""Raised when order_by contains a field not in the allowed order fields."""
|
||||||
|
|
||||||
|
api_error = ApiError(
|
||||||
|
code=422,
|
||||||
|
msg="Invalid Order Field",
|
||||||
|
desc="The requested order field is not allowed for this resource.",
|
||||||
|
err_code="SORT-422",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, field: str, valid_fields: list[str]) -> None:
|
||||||
|
"""Initialize the exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field: The unknown order field provided by the caller
|
||||||
|
valid_fields: List of valid field names
|
||||||
|
"""
|
||||||
|
self.field = field
|
||||||
|
self.valid_fields = valid_fields
|
||||||
|
detail = (
|
||||||
|
f"'{field}' is not an allowed order field. Valid fields: {valid_fields}."
|
||||||
|
)
|
||||||
|
super().__init__(detail)
|
||||||
|
|
||||||
|
|
||||||
def generate_error_responses(
|
def generate_error_responses(
|
||||||
*errors: type[ApiException],
|
*errors: type[ApiException],
|
||||||
) -> dict[int | str, dict[str, Any]]:
|
) -> dict[int | str, dict[str, Any]]:
|
||||||
@@ -152,6 +167,7 @@ def generate_error_responses(
|
|||||||
Dict suitable for FastAPI's responses parameter
|
Dict suitable for FastAPI's responses parameter
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
from fastapi_toolsets.exceptions import generate_error_responses, UnauthorizedError, ForbiddenError
|
from fastapi_toolsets.exceptions import generate_error_responses, UnauthorizedError, ForbiddenError
|
||||||
|
|
||||||
@app.get(
|
@app.get(
|
||||||
@@ -160,6 +176,7 @@ def generate_error_responses(
|
|||||||
)
|
)
|
||||||
async def admin_endpoint():
|
async def admin_endpoint():
|
||||||
...
|
...
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
responses: dict[int | str, dict[str, Any]] = {}
|
responses: dict[int | str, dict[str, Any]] = {}
|
||||||
|
|
||||||
@@ -172,7 +189,7 @@ def generate_error_responses(
|
|||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"example": {
|
"example": {
|
||||||
"data": None,
|
"data": api_error.data,
|
||||||
"status": ResponseStatus.FAIL.value,
|
"status": ResponseStatus.FAIL.value,
|
||||||
"message": api_error.msg,
|
"message": api_error.msg,
|
||||||
"description": api_error.desc,
|
"description": api_error.desc,
|
||||||
|
|||||||
@@ -7,11 +7,32 @@ from fastapi.exceptions import RequestValidationError, ResponseValidationError
|
|||||||
from fastapi.openapi.utils import get_openapi
|
from fastapi.openapi.utils import get_openapi
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from ..schemas import ResponseStatus
|
from ..schemas import ErrorResponse, ResponseStatus
|
||||||
from .exceptions import ApiException
|
from .exceptions import ApiException
|
||||||
|
|
||||||
|
|
||||||
def init_exceptions_handlers(app: FastAPI) -> FastAPI:
|
def init_exceptions_handlers(app: FastAPI) -> FastAPI:
|
||||||
|
"""Register exception handlers and custom OpenAPI schema on a FastAPI app.
|
||||||
|
|
||||||
|
Installs handlers for :class:`ApiException`, validation errors, and
|
||||||
|
unhandled exceptions, and replaces the default 422 schema with a
|
||||||
|
consistent error format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI application instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The same FastAPI instance (for chaining)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
```
|
||||||
|
"""
|
||||||
_register_exception_handlers(app)
|
_register_exception_handlers(app)
|
||||||
app.openapi = lambda: _custom_openapi(app) # type: ignore[method-assign]
|
app.openapi = lambda: _custom_openapi(app) # type: ignore[method-assign]
|
||||||
return app
|
return app
|
||||||
@@ -35,16 +56,16 @@ def _register_exception_handlers(app: FastAPI) -> None:
|
|||||||
async def api_exception_handler(request: Request, exc: ApiException) -> Response:
|
async def api_exception_handler(request: Request, exc: ApiException) -> Response:
|
||||||
"""Handle custom API exceptions with structured response."""
|
"""Handle custom API exceptions with structured response."""
|
||||||
api_error = exc.api_error
|
api_error = exc.api_error
|
||||||
|
error_response = ErrorResponse(
|
||||||
|
data=api_error.data,
|
||||||
|
message=api_error.msg,
|
||||||
|
description=api_error.desc,
|
||||||
|
error_code=api_error.err_code,
|
||||||
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=api_error.code,
|
status_code=api_error.code,
|
||||||
content={
|
content=error_response.model_dump(),
|
||||||
"data": None,
|
|
||||||
"status": ResponseStatus.FAIL.value,
|
|
||||||
"message": api_error.msg,
|
|
||||||
"description": api_error.desc,
|
|
||||||
"error_code": api_error.err_code,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.exception_handler(RequestValidationError)
|
@app.exception_handler(RequestValidationError)
|
||||||
@@ -64,15 +85,15 @@ def _register_exception_handlers(app: FastAPI) -> None:
|
|||||||
@app.exception_handler(Exception)
|
@app.exception_handler(Exception)
|
||||||
async def generic_exception_handler(request: Request, exc: Exception) -> Response:
|
async def generic_exception_handler(request: Request, exc: Exception) -> Response:
|
||||||
"""Handle all unhandled exceptions with a generic 500 response."""
|
"""Handle all unhandled exceptions with a generic 500 response."""
|
||||||
|
error_response = ErrorResponse(
|
||||||
|
message="Internal Server Error",
|
||||||
|
description="An unexpected error occurred. Please try again later.",
|
||||||
|
error_code="SERVER-500",
|
||||||
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
content={
|
content=error_response.model_dump(),
|
||||||
"data": None,
|
|
||||||
"status": ResponseStatus.FAIL.value,
|
|
||||||
"message": "Internal Server Error",
|
|
||||||
"description": "An unexpected error occurred. Please try again later.",
|
|
||||||
"error_code": "SERVER-500",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -97,15 +118,16 @@ def _format_validation_error(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
error_response = ErrorResponse(
|
||||||
|
data={"errors": formatted_errors},
|
||||||
|
message="Validation Error",
|
||||||
|
description=f"{len(formatted_errors)} validation error(s) detected",
|
||||||
|
error_code="VAL-422",
|
||||||
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
content={
|
content=error_response.model_dump(),
|
||||||
"data": {"errors": formatted_errors},
|
|
||||||
"status": ResponseStatus.FAIL.value,
|
|
||||||
"message": "Validation Error",
|
|
||||||
"description": f"{len(formatted_errors)} validation error(s) detected",
|
|
||||||
"error_code": "VAL-422",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""Fixture system for seeding databases with dependency resolution."""
|
||||||
|
|
||||||
from .enum import LoadStrategy
|
from .enum import LoadStrategy
|
||||||
from .registry import Context, FixtureRegistry
|
from .registry import Context, FixtureRegistry
|
||||||
from .utils import get_obj_by_attr, load_fixtures, load_fixtures_by_context
|
from .utils import get_obj_by_attr, load_fixtures, load_fixtures_by_context
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""Enums for fixture loading strategies and contexts."""
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
"""Fixture system with dependency management and context support."""
|
"""Fixture system with dependency management and context support."""
|
||||||
|
|
||||||
import logging
|
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from ..logger import get_logger
|
||||||
from .enum import Context
|
from .enum import Context
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -26,6 +26,7 @@ class FixtureRegistry:
|
|||||||
"""Registry for managing fixtures with dependencies.
|
"""Registry for managing fixtures with dependencies.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
from fastapi_toolsets.fixtures import FixtureRegistry, Context
|
from fastapi_toolsets.fixtures import FixtureRegistry, Context
|
||||||
|
|
||||||
fixtures = FixtureRegistry()
|
fixtures = FixtureRegistry()
|
||||||
@@ -48,10 +49,19 @@ class FixtureRegistry:
|
|||||||
return [
|
return [
|
||||||
Post(id=1, title="Test", user_id=1),
|
Post(id=1, title="Test", user_id=1),
|
||||||
]
|
]
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
contexts: list[str | Context] | None = None,
|
||||||
|
) -> None:
|
||||||
self._fixtures: dict[str, Fixture] = {}
|
self._fixtures: dict[str, Fixture] = {}
|
||||||
|
self._default_contexts: list[str] | None = (
|
||||||
|
[c.value if isinstance(c, Context) else c for c in contexts]
|
||||||
|
if contexts
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
def register(
|
def register(
|
||||||
self,
|
self,
|
||||||
@@ -72,6 +82,7 @@ class FixtureRegistry:
|
|||||||
contexts: List of contexts this fixture belongs to
|
contexts: List of contexts this fixture belongs to
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
@fixtures.register
|
@fixtures.register
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=1, name="admin")]
|
||||||
@@ -79,16 +90,21 @@ class FixtureRegistry:
|
|||||||
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
|
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
|
||||||
def test_users():
|
def test_users():
|
||||||
return [User(id=1, username="test", role_id=1)]
|
return [User(id=1, username="test", role_id=1)]
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(
|
def decorator(
|
||||||
fn: Callable[[], Sequence[DeclarativeBase]],
|
fn: Callable[[], Sequence[DeclarativeBase]],
|
||||||
) -> Callable[[], Sequence[DeclarativeBase]]:
|
) -> Callable[[], Sequence[DeclarativeBase]]:
|
||||||
fixture_name = name or cast(Any, fn).__name__
|
fixture_name = name or cast(Any, fn).__name__
|
||||||
fixture_contexts = [
|
if contexts is not None:
|
||||||
c.value if isinstance(c, Context) else c
|
fixture_contexts = [
|
||||||
for c in (contexts or [Context.BASE])
|
c.value if isinstance(c, Context) else c for c in contexts
|
||||||
]
|
]
|
||||||
|
elif self._default_contexts is not None:
|
||||||
|
fixture_contexts = self._default_contexts
|
||||||
|
else:
|
||||||
|
fixture_contexts = [Context.BASE.value]
|
||||||
|
|
||||||
self._fixtures[fixture_name] = Fixture(
|
self._fixtures[fixture_name] = Fixture(
|
||||||
name=fixture_name,
|
name=fixture_name,
|
||||||
@@ -102,6 +118,34 @@ class FixtureRegistry:
|
|||||||
return decorator(func)
|
return decorator(func)
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
def include_registry(self, registry: "FixtureRegistry") -> None:
|
||||||
|
"""Include another `FixtureRegistry` in the same current `FixtureRegistry`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registry: The `FixtureRegistry` to include
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If a fixture name already exists in the current registry
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
dev_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@dev_registry.register
|
||||||
|
def dev_data():
|
||||||
|
return [...]
|
||||||
|
|
||||||
|
registry.include_registry(registry=dev_registry)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
for name, fixture in registry._fixtures.items():
|
||||||
|
if name in self._fixtures:
|
||||||
|
raise ValueError(
|
||||||
|
f"Fixture '{name}' already exists in the current registry"
|
||||||
|
)
|
||||||
|
self._fixtures[name] = fixture
|
||||||
|
|
||||||
def get(self, name: str) -> Fixture:
|
def get(self, name: str) -> Fixture:
|
||||||
"""Get a fixture by name."""
|
"""Get a fixture by name."""
|
||||||
if name not in self._fixtures:
|
if name not in self._fixtures:
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import logging
|
"""Fixture loading utilities for database seeding."""
|
||||||
|
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from typing import Any, TypeVar
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
@@ -6,10 +7,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
from ..db import get_transaction
|
from ..db import get_transaction
|
||||||
|
from ..logger import get_logger
|
||||||
from .enum import LoadStrategy
|
from .enum import LoadStrategy
|
||||||
from .registry import Context, FixtureRegistry
|
from .registry import Context, FixtureRegistry
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = get_logger()
|
||||||
|
|
||||||
T = TypeVar("T", bound=DeclarativeBase)
|
T = TypeVar("T", bound=DeclarativeBase)
|
||||||
|
|
||||||
@@ -29,9 +31,14 @@ def get_obj_by_attr(
|
|||||||
The first model instance where the attribute matches the given value.
|
The first model instance where the attribute matches the given value.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
StopIteration: If no matching object is found.
|
StopIteration: If no matching object is found in the fixture group.
|
||||||
"""
|
"""
|
||||||
return next(obj for obj in fixtures() if getattr(obj, attr_name) == value)
|
try:
|
||||||
|
return next(obj for obj in fixtures() if getattr(obj, attr_name) == value)
|
||||||
|
except StopIteration:
|
||||||
|
raise StopIteration(
|
||||||
|
f"No object with {attr_name}={value} found in fixture '{getattr(fixtures, '__name__', repr(fixtures))}'"
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
async def load_fixtures(
|
async def load_fixtures(
|
||||||
@@ -52,9 +59,11 @@ async def load_fixtures(
|
|||||||
Dict mapping fixture names to loaded instances
|
Dict mapping fixture names to loaded instances
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
# Loads 'roles' first (dependency), then 'users'
|
# Loads 'roles' first (dependency), then 'users'
|
||||||
result = await load_fixtures(session, fixtures, "users")
|
result = await load_fixtures(session, fixtures, "users")
|
||||||
print(result["users"]) # [User(...), ...]
|
print(result["users"]) # [User(...), ...]
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
ordered = registry.resolve_dependencies(*names)
|
ordered = registry.resolve_dependencies(*names)
|
||||||
return await _load_ordered(session, registry, ordered, strategy)
|
return await _load_ordered(session, registry, ordered, strategy)
|
||||||
@@ -78,11 +87,13 @@ async def load_fixtures_by_context(
|
|||||||
Dict mapping fixture names to loaded instances
|
Dict mapping fixture names to loaded instances
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
# Load base + testing fixtures
|
# Load base + testing fixtures
|
||||||
await load_fixtures_by_context(
|
await load_fixtures_by_context(
|
||||||
session, fixtures,
|
session, fixtures,
|
||||||
Context.BASE, Context.TESTING
|
Context.BASE, Context.TESTING
|
||||||
)
|
)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
ordered = registry.resolve_context_dependencies(*contexts)
|
ordered = registry.resolve_context_dependencies(*contexts)
|
||||||
return await _load_ordered(session, registry, ordered, strategy)
|
return await _load_ordered(session, registry, ordered, strategy)
|
||||||
|
|||||||
98
src/fastapi_toolsets/logger.py
Normal file
98
src/fastapi_toolsets/logger.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""Logging configuration for FastAPI applications and CLI tools."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
__all__ = ["LogLevel", "configure_logging", "get_logger"]
|
||||||
|
|
||||||
|
DEFAULT_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
UVICORN_LOGGERS = ("uvicorn", "uvicorn.access", "uvicorn.error")
|
||||||
|
|
||||||
|
LogLevel = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
|
||||||
|
|
||||||
|
|
||||||
|
def configure_logging(
|
||||||
|
level: LogLevel | int = "INFO",
|
||||||
|
fmt: str = DEFAULT_FORMAT,
|
||||||
|
logger_name: str | None = None,
|
||||||
|
) -> logging.Logger:
|
||||||
|
"""Configure logging with a stdout handler and consistent format.
|
||||||
|
|
||||||
|
Sets up a :class:`~logging.StreamHandler` writing to stdout with the
|
||||||
|
given format and level. Also configures the uvicorn loggers so that
|
||||||
|
FastAPI access logs use the same format.
|
||||||
|
|
||||||
|
Calling this function multiple times is safe -- existing handlers are
|
||||||
|
replaced rather than duplicated.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: Log level (e.g. ``"DEBUG"``, ``"INFO"``, or ``logging.DEBUG``).
|
||||||
|
fmt: Log format string. Defaults to
|
||||||
|
``"%(asctime)s - %(name)s - %(levelname)s - %(message)s"``.
|
||||||
|
logger_name: Logger name to configure. ``None`` (the default)
|
||||||
|
configures the root logger so all loggers inherit the settings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The configured Logger instance.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import configure_logging
|
||||||
|
|
||||||
|
logger = configure_logging("DEBUG")
|
||||||
|
logger.info("Application started")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
formatter = logging.Formatter(fmt)
|
||||||
|
|
||||||
|
handler = logging.StreamHandler(sys.stdout)
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
logger = logging.getLogger(logger_name)
|
||||||
|
logger.handlers.clear()
|
||||||
|
logger.addHandler(handler)
|
||||||
|
logger.setLevel(level)
|
||||||
|
|
||||||
|
for name in UVICORN_LOGGERS:
|
||||||
|
uv_logger = logging.getLogger(name)
|
||||||
|
uv_logger.handlers.clear()
|
||||||
|
uv_logger.addHandler(handler)
|
||||||
|
uv_logger.setLevel(level)
|
||||||
|
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
_SENTINEL = object()
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(name: str | None = _SENTINEL) -> logging.Logger: # type: ignore[assignment]
|
||||||
|
"""Return a logger with the given *name*.
|
||||||
|
|
||||||
|
A thin convenience wrapper around :func:`logging.getLogger` that keeps
|
||||||
|
logging imports consistent across the codebase.
|
||||||
|
|
||||||
|
When called without arguments, the caller's ``__name__`` is used
|
||||||
|
automatically, so ``get_logger()`` in a module is equivalent to
|
||||||
|
``logging.getLogger(__name__)``. Pass ``None`` explicitly to get the
|
||||||
|
root logger.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Logger name. Defaults to the caller's ``__name__``.
|
||||||
|
Pass ``None`` to get the root logger.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Logger instance.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger() # uses caller's __name__
|
||||||
|
logger = get_logger("myapp") # explicit name
|
||||||
|
logger = get_logger(None) # root logger
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if name is _SENTINEL:
|
||||||
|
name = sys._getframe(1).f_globals.get("__name__")
|
||||||
|
return logging.getLogger(name)
|
||||||
21
src/fastapi_toolsets/metrics/__init__.py
Normal file
21
src/fastapi_toolsets/metrics/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""Prometheus metrics integration for FastAPI applications."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .registry import Metric, MetricsRegistry
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .handler import init_metrics
|
||||||
|
except ImportError:
|
||||||
|
|
||||||
|
def init_metrics(*_args: Any, **_kwargs: Any) -> None:
|
||||||
|
from .._imports import require_extra
|
||||||
|
|
||||||
|
require_extra(package="prometheus_client", extra="metrics")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Metric",
|
||||||
|
"MetricsRegistry",
|
||||||
|
"init_metrics",
|
||||||
|
]
|
||||||
75
src/fastapi_toolsets/metrics/handler.py
Normal file
75
src/fastapi_toolsets/metrics/handler.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
"""Prometheus metrics endpoint for FastAPI applications."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.responses import Response
|
||||||
|
from prometheus_client import (
|
||||||
|
CONTENT_TYPE_LATEST,
|
||||||
|
CollectorRegistry,
|
||||||
|
generate_latest,
|
||||||
|
multiprocess,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..logger import get_logger
|
||||||
|
from .registry import MetricsRegistry
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _is_multiprocess() -> bool:
|
||||||
|
"""Check if prometheus multi-process mode is enabled."""
|
||||||
|
return "PROMETHEUS_MULTIPROC_DIR" in os.environ
|
||||||
|
|
||||||
|
|
||||||
|
def init_metrics(
|
||||||
|
app: FastAPI,
|
||||||
|
registry: MetricsRegistry,
|
||||||
|
*,
|
||||||
|
path: str = "/metrics",
|
||||||
|
) -> FastAPI:
|
||||||
|
"""Register a Prometheus ``/metrics`` endpoint on a FastAPI app.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI application instance.
|
||||||
|
registry: A :class:`MetricsRegistry` containing providers and collectors.
|
||||||
|
path: URL path for the metrics endpoint (default ``/metrics``).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The same FastAPI instance (for chaining).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.metrics import MetricsRegistry, init_metrics
|
||||||
|
|
||||||
|
metrics = MetricsRegistry()
|
||||||
|
app = FastAPI()
|
||||||
|
init_metrics(app, registry=metrics)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
for provider in registry.get_providers():
|
||||||
|
logger.debug("Initialising metric provider '%s'", provider.name)
|
||||||
|
provider.func()
|
||||||
|
|
||||||
|
collectors = registry.get_collectors()
|
||||||
|
|
||||||
|
@app.get(path, include_in_schema=False)
|
||||||
|
async def metrics_endpoint() -> Response:
|
||||||
|
for collector in collectors:
|
||||||
|
if asyncio.iscoroutinefunction(collector.func):
|
||||||
|
await collector.func()
|
||||||
|
else:
|
||||||
|
collector.func()
|
||||||
|
|
||||||
|
if _is_multiprocess():
|
||||||
|
prom_registry = CollectorRegistry()
|
||||||
|
multiprocess.MultiProcessCollector(prom_registry)
|
||||||
|
output = generate_latest(prom_registry)
|
||||||
|
else:
|
||||||
|
output = generate_latest()
|
||||||
|
|
||||||
|
return Response(content=output, media_type=CONTENT_TYPE_LATEST)
|
||||||
|
|
||||||
|
return app
|
||||||
128
src/fastapi_toolsets/metrics/registry.py
Normal file
128
src/fastapi_toolsets/metrics/registry.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
"""Metrics registry with decorator-based registration."""
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from ..logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Metric:
|
||||||
|
"""A metric definition with metadata."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
func: Callable[..., Any]
|
||||||
|
collect: bool = field(default=False)
|
||||||
|
|
||||||
|
|
||||||
|
class MetricsRegistry:
|
||||||
|
"""Registry for managing Prometheus metric providers and collectors.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from prometheus_client import Counter, Gauge
|
||||||
|
from fastapi_toolsets.metrics import MetricsRegistry
|
||||||
|
|
||||||
|
metrics = MetricsRegistry()
|
||||||
|
|
||||||
|
@metrics.register
|
||||||
|
def http_requests():
|
||||||
|
return Counter("http_requests_total", "Total HTTP requests", ["method", "status"])
|
||||||
|
|
||||||
|
@metrics.register(name="db_pool")
|
||||||
|
def database_pool_size():
|
||||||
|
return Gauge("db_pool_size", "Database connection pool size")
|
||||||
|
|
||||||
|
@metrics.register(collect=True)
|
||||||
|
def collect_queue_depth(gauge=Gauge("queue_depth", "Current queue depth")):
|
||||||
|
gauge.set(get_current_queue_depth())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._metrics: dict[str, Metric] = {}
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self,
|
||||||
|
func: Callable[..., Any] | None = None,
|
||||||
|
*,
|
||||||
|
name: str | None = None,
|
||||||
|
collect: bool = False,
|
||||||
|
) -> Callable[..., Any]:
|
||||||
|
"""Register a metric provider or collector function.
|
||||||
|
|
||||||
|
Can be used as a decorator with or without arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The metric function to register.
|
||||||
|
name: Metric name (defaults to function name).
|
||||||
|
collect: If ``True``, the function is called on every scrape.
|
||||||
|
If ``False`` (default), called once at init time.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
@metrics.register
|
||||||
|
def my_counter():
|
||||||
|
return Counter("my_counter", "A counter")
|
||||||
|
|
||||||
|
@metrics.register(collect=True, name="queue")
|
||||||
|
def collect_queue_depth():
|
||||||
|
gauge.set(compute_depth())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
|
metric_name = name or cast(Any, fn).__name__
|
||||||
|
self._metrics[metric_name] = Metric(
|
||||||
|
name=metric_name,
|
||||||
|
func=fn,
|
||||||
|
collect=collect,
|
||||||
|
)
|
||||||
|
return fn
|
||||||
|
|
||||||
|
if func is not None:
|
||||||
|
return decorator(func)
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def include_registry(self, registry: "MetricsRegistry") -> None:
|
||||||
|
"""Include another :class:`MetricsRegistry` into this one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registry: The registry to merge in.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If a metric name already exists in the current registry.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
main = MetricsRegistry()
|
||||||
|
sub = MetricsRegistry()
|
||||||
|
|
||||||
|
@sub.register
|
||||||
|
def sub_metric():
|
||||||
|
return Counter("sub_total", "Sub counter")
|
||||||
|
|
||||||
|
main.include_registry(sub)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
for metric_name, definition in registry._metrics.items():
|
||||||
|
if metric_name in self._metrics:
|
||||||
|
raise ValueError(
|
||||||
|
f"Metric '{metric_name}' already exists in the current registry"
|
||||||
|
)
|
||||||
|
self._metrics[metric_name] = definition
|
||||||
|
|
||||||
|
def get_all(self) -> list[Metric]:
|
||||||
|
"""Get all registered metric definitions."""
|
||||||
|
return list(self._metrics.values())
|
||||||
|
|
||||||
|
def get_providers(self) -> list[Metric]:
|
||||||
|
"""Get metric providers (called once at init)."""
|
||||||
|
return [m for m in self._metrics.values() if not m.collect]
|
||||||
|
|
||||||
|
def get_collectors(self) -> list[Metric]:
|
||||||
|
"""Get collectors (called on each scrape)."""
|
||||||
|
return [m for m in self._metrics.values() if m.collect]
|
||||||
@@ -1,8 +1,30 @@
|
|||||||
from .plugin import register_fixtures
|
"""Pytest helpers for FastAPI testing: sessions, clients, and fixtures."""
|
||||||
from .utils import create_async_client, create_db_session
|
|
||||||
|
try:
|
||||||
|
from .plugin import register_fixtures
|
||||||
|
except ImportError:
|
||||||
|
from .._imports import require_extra
|
||||||
|
|
||||||
|
require_extra(package="pytest", extra="pytest")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .utils import (
|
||||||
|
cleanup_tables,
|
||||||
|
create_async_client,
|
||||||
|
create_db_session,
|
||||||
|
create_worker_database,
|
||||||
|
worker_database_url,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
from .._imports import require_extra
|
||||||
|
|
||||||
|
require_extra(package="httpx", extra="pytest")
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"cleanup_tables",
|
||||||
"create_async_client",
|
"create_async_client",
|
||||||
"create_db_session",
|
"create_db_session",
|
||||||
|
"create_worker_database",
|
||||||
"register_fixtures",
|
"register_fixtures",
|
||||||
|
"worker_database_url",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,55 +1,4 @@
|
|||||||
"""Pytest plugin for using FixtureRegistry fixtures in tests.
|
"""Pytest plugin for using FixtureRegistry fixtures in tests."""
|
||||||
|
|
||||||
This module provides utilities to automatically generate pytest fixtures
|
|
||||||
from your FixtureRegistry, with proper dependency resolution.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# conftest.py
|
|
||||||
import pytest
|
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
|
||||||
|
|
||||||
from app.fixtures import fixtures # Your FixtureRegistry
|
|
||||||
from app.models import Base
|
|
||||||
from fastapi_toolsets.pytest_plugin import register_fixtures
|
|
||||||
|
|
||||||
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/test_db"
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def engine():
|
|
||||||
engine = create_async_engine(DATABASE_URL)
|
|
||||||
yield engine
|
|
||||||
await engine.dispose()
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def db_session(engine):
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
|
||||||
|
|
||||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
|
||||||
session = session_factory()
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield session
|
|
||||||
finally:
|
|
||||||
await session.close()
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.drop_all)
|
|
||||||
|
|
||||||
# Automatically generate pytest fixtures from registry
|
|
||||||
# Creates: fixture_roles, fixture_users, fixture_posts, etc.
|
|
||||||
register_fixtures(fixtures, globals())
|
|
||||||
|
|
||||||
Usage in tests:
|
|
||||||
# test_users.py
|
|
||||||
async def test_user_count(db_session, fixture_users):
|
|
||||||
# fixture_users automatically loads fixture_roles first (if dependency)
|
|
||||||
# and returns the list of User models
|
|
||||||
assert len(fixture_users) > 0
|
|
||||||
|
|
||||||
async def test_user_role(db_session, fixture_users):
|
|
||||||
user = fixture_users[0]
|
|
||||||
assert user.role_id is not None
|
|
||||||
"""
|
|
||||||
|
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -86,6 +35,7 @@ def register_fixtures(
|
|||||||
List of created fixture names
|
List of created fixture names
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
# conftest.py
|
# conftest.py
|
||||||
from app.fixtures import fixtures
|
from app.fixtures import fixtures
|
||||||
from fastapi_toolsets.pytest_plugin import register_fixtures
|
from fastapi_toolsets.pytest_plugin import register_fixtures
|
||||||
@@ -96,6 +46,7 @@ def register_fixtures(
|
|||||||
# - fixture_roles
|
# - fixture_roles
|
||||||
# - fixture_users (depends on fixture_roles if users depends on roles)
|
# - fixture_users (depends on fixture_roles if users depends on roles)
|
||||||
# - fixture_posts (depends on fixture_users if posts depends on users)
|
# - fixture_posts (depends on fixture_users if posts depends on users)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
created_fixtures: list[str] = []
|
created_fixtures: list[str] = []
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,18 @@
|
|||||||
"""Pytest helper utilities for FastAPI testing."""
|
"""Pytest helper utilities for FastAPI testing."""
|
||||||
|
|
||||||
from collections.abc import AsyncGenerator
|
import os
|
||||||
|
from collections.abc import AsyncGenerator, Callable
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from httpx import ASGITransport, AsyncClient
|
from httpx import ASGITransport, AsyncClient
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.engine import make_url
|
||||||
|
from sqlalchemy.ext.asyncio import (
|
||||||
|
AsyncSession,
|
||||||
|
async_sessionmaker,
|
||||||
|
create_async_engine,
|
||||||
|
)
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
from ..db import create_db_context
|
from ..db import create_db_context
|
||||||
@@ -15,12 +22,16 @@ from ..db import create_db_context
|
|||||||
async def create_async_client(
|
async def create_async_client(
|
||||||
app: Any,
|
app: Any,
|
||||||
base_url: str = "http://test",
|
base_url: str = "http://test",
|
||||||
|
dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None,
|
||||||
) -> AsyncGenerator[AsyncClient, None]:
|
) -> AsyncGenerator[AsyncClient, None]:
|
||||||
"""Create an async httpx client for testing FastAPI applications.
|
"""Create an async httpx client for testing FastAPI applications.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
app: FastAPI application instance.
|
app: FastAPI application instance.
|
||||||
base_url: Base URL for requests. Defaults to "http://test".
|
base_url: Base URL for requests. Defaults to "http://test".
|
||||||
|
dependency_overrides: Optional mapping of original dependencies to
|
||||||
|
their test replacements. Applied via ``app.dependency_overrides``
|
||||||
|
before yielding and cleaned up after.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
An AsyncClient configured for the app.
|
An AsyncClient configured for the app.
|
||||||
@@ -41,10 +52,39 @@ async def create_async_client(
|
|||||||
response = await client.get("/health")
|
response = await client.get("/health")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Example with dependency overrides:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import create_async_client, create_db_session
|
||||||
|
from app.db import get_db
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session():
|
||||||
|
async with create_db_session(DATABASE_URL, Base, cleanup=True) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client(db_session):
|
||||||
|
async def override():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
async with create_async_client(
|
||||||
|
app, dependency_overrides={get_db: override}
|
||||||
|
) as c:
|
||||||
|
yield c
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
if dependency_overrides:
|
||||||
|
app.dependency_overrides.update(dependency_overrides)
|
||||||
|
|
||||||
transport = ASGITransport(app=app)
|
transport = ASGITransport(app=app)
|
||||||
async with AsyncClient(transport=transport, base_url=base_url) as client:
|
try:
|
||||||
yield client
|
async with AsyncClient(transport=transport, base_url=base_url) as client:
|
||||||
|
yield client
|
||||||
|
finally:
|
||||||
|
if dependency_overrides:
|
||||||
|
for key in dependency_overrides:
|
||||||
|
app.dependency_overrides.pop(key, None)
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@@ -55,6 +95,7 @@ async def create_db_session(
|
|||||||
echo: bool = False,
|
echo: bool = False,
|
||||||
expire_on_commit: bool = False,
|
expire_on_commit: bool = False,
|
||||||
drop_tables: bool = True,
|
drop_tables: bool = True,
|
||||||
|
cleanup: bool = False,
|
||||||
) -> AsyncGenerator[AsyncSession, None]:
|
) -> AsyncGenerator[AsyncSession, None]:
|
||||||
"""Create a database session for testing.
|
"""Create a database session for testing.
|
||||||
|
|
||||||
@@ -67,6 +108,8 @@ async def create_db_session(
|
|||||||
echo: Enable SQLAlchemy query logging. Defaults to False.
|
echo: Enable SQLAlchemy query logging. Defaults to False.
|
||||||
expire_on_commit: Expire objects after commit. Defaults to False.
|
expire_on_commit: Expire objects after commit. Defaults to False.
|
||||||
drop_tables: Drop tables after test. Defaults to True.
|
drop_tables: Drop tables after test. Defaults to True.
|
||||||
|
cleanup: Truncate all tables after test using
|
||||||
|
:func:`cleanup_tables`. Defaults to False.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
An AsyncSession ready for database operations.
|
An AsyncSession ready for database operations.
|
||||||
@@ -80,7 +123,9 @@ async def create_db_session(
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def db_session():
|
async def db_session():
|
||||||
async with create_db_session(DATABASE_URL, Base) as session:
|
async with create_db_session(
|
||||||
|
DATABASE_URL, Base, cleanup=True
|
||||||
|
) as session:
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
async def test_create_user(db_session: AsyncSession):
|
async def test_create_user(db_session: AsyncSession):
|
||||||
@@ -103,8 +148,168 @@ async def create_db_session(
|
|||||||
async with get_session() as session:
|
async with get_session() as session:
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
if cleanup:
|
||||||
|
await cleanup_tables(session, base)
|
||||||
|
|
||||||
if drop_tables:
|
if drop_tables:
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(base.metadata.drop_all)
|
await conn.run_sync(base.metadata.drop_all)
|
||||||
finally:
|
finally:
|
||||||
await engine.dispose()
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_xdist_worker(default_test_db: str) -> str:
|
||||||
|
"""Return the pytest-xdist worker name, or *default_test_db* when not running under xdist.
|
||||||
|
|
||||||
|
Reads the ``PYTEST_XDIST_WORKER`` environment variable that xdist sets
|
||||||
|
automatically in each worker process (e.g. ``"gw0"``, ``"gw1"``).
|
||||||
|
When xdist is not installed or not active, the variable is absent and
|
||||||
|
*default_test_db* is returned instead.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
default_test_db: Fallback value returned when ``PYTEST_XDIST_WORKER``
|
||||||
|
is not set.
|
||||||
|
"""
|
||||||
|
return os.environ.get("PYTEST_XDIST_WORKER", default_test_db)
|
||||||
|
|
||||||
|
|
||||||
|
def worker_database_url(database_url: str, default_test_db: str) -> str:
|
||||||
|
"""Derive a per-worker database URL for pytest-xdist parallel runs.
|
||||||
|
|
||||||
|
Appends ``_{worker_name}`` to the database name so each xdist worker
|
||||||
|
operates on its own database. When not running under xdist,
|
||||||
|
``_{default_test_db}`` is appended instead.
|
||||||
|
|
||||||
|
The worker name is read from the ``PYTEST_XDIST_WORKER`` environment
|
||||||
|
variable (set automatically by xdist in each worker process).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database_url: Original database connection URL.
|
||||||
|
default_test_db: Suffix appended to the database name when
|
||||||
|
``PYTEST_XDIST_WORKER`` is not set.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A database URL with a worker- or default-specific database name.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
# With PYTEST_XDIST_WORKER="gw0":
|
||||||
|
url = worker_database_url(
|
||||||
|
"postgresql+asyncpg://user:pass@localhost/test_db",
|
||||||
|
default_test_db="test",
|
||||||
|
)
|
||||||
|
# "postgresql+asyncpg://user:pass@localhost/test_db_gw0"
|
||||||
|
|
||||||
|
# Without PYTEST_XDIST_WORKER:
|
||||||
|
url = worker_database_url(
|
||||||
|
"postgresql+asyncpg://user:pass@localhost/test_db",
|
||||||
|
default_test_db="test",
|
||||||
|
)
|
||||||
|
# "postgresql+asyncpg://user:pass@localhost/test_db_test"
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
worker = _get_xdist_worker(default_test_db=default_test_db)
|
||||||
|
|
||||||
|
url = make_url(database_url)
|
||||||
|
url = url.set(database=f"{url.database}_{worker}")
|
||||||
|
return url.render_as_string(hide_password=False)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def create_worker_database(
|
||||||
|
database_url: str,
|
||||||
|
default_test_db: str = "test_db",
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""Create and drop a per-worker database for pytest-xdist isolation.
|
||||||
|
|
||||||
|
Intended for use as a **session-scoped** fixture. Connects to the server
|
||||||
|
using the original *database_url* (with ``AUTOCOMMIT`` isolation for DDL),
|
||||||
|
creates a dedicated database for the worker, and yields the worker-specific
|
||||||
|
URL. On cleanup the worker database is dropped.
|
||||||
|
|
||||||
|
When running under xdist the database name is suffixed with the worker
|
||||||
|
name (e.g. ``_gw0``). Otherwise it is suffixed with *default_test_db*.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database_url: Original database connection URL.
|
||||||
|
default_test_db: Suffix appended to the database name when
|
||||||
|
``PYTEST_XDIST_WORKER`` is not set. Defaults to ``"test_db"``.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The worker-specific database URL.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import (
|
||||||
|
create_worker_database, create_db_session,
|
||||||
|
)
|
||||||
|
|
||||||
|
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost/test_db"
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
async def worker_db_url():
|
||||||
|
async with create_worker_database(DATABASE_URL) as url:
|
||||||
|
yield url
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session(worker_db_url):
|
||||||
|
async with create_db_session(
|
||||||
|
worker_db_url, Base, cleanup=True
|
||||||
|
) as session:
|
||||||
|
yield session
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
worker_url = worker_database_url(
|
||||||
|
database_url=database_url, default_test_db=default_test_db
|
||||||
|
)
|
||||||
|
worker_db_name = make_url(worker_url).database
|
||||||
|
|
||||||
|
engine = create_async_engine(
|
||||||
|
database_url,
|
||||||
|
isolation_level="AUTOCOMMIT",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
await conn.execute(text(f"DROP DATABASE IF EXISTS {worker_db_name}"))
|
||||||
|
await conn.execute(text(f"CREATE DATABASE {worker_db_name}"))
|
||||||
|
|
||||||
|
yield worker_url
|
||||||
|
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
await conn.execute(text(f"DROP DATABASE IF EXISTS {worker_db_name}"))
|
||||||
|
finally:
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
async def cleanup_tables(
|
||||||
|
session: AsyncSession,
|
||||||
|
base: type[DeclarativeBase],
|
||||||
|
) -> None:
|
||||||
|
"""Truncate all tables for fast between-test cleanup.
|
||||||
|
|
||||||
|
Executes a single ``TRUNCATE … RESTART IDENTITY CASCADE`` statement
|
||||||
|
across every table in *base*'s metadata, which is significantly faster
|
||||||
|
than dropping and re-creating tables between tests.
|
||||||
|
|
||||||
|
This is a no-op when the metadata contains no tables.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: An active async database session.
|
||||||
|
base: SQLAlchemy DeclarativeBase class containing model metadata.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session(worker_db_url):
|
||||||
|
async with create_db_session(worker_db_url, Base) as session:
|
||||||
|
yield session
|
||||||
|
await cleanup_tables(session, Base)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
tables = base.metadata.sorted_tables
|
||||||
|
if not tables:
|
||||||
|
return
|
||||||
|
|
||||||
|
table_names = ", ".join(f'"{t.name}"' for t in tables)
|
||||||
|
await session.execute(text(f"TRUNCATE {table_names} RESTART IDENTITY CASCADE"))
|
||||||
|
await session.commit()
|
||||||
|
|||||||
@@ -1,15 +1,18 @@
|
|||||||
"""Base Pydantic schemas for API responses."""
|
"""Base Pydantic schemas for API responses."""
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import ClassVar, Generic, TypeVar
|
from typing import Any, ClassVar, Generic, TypeVar
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ApiError",
|
"ApiError",
|
||||||
|
"CursorPagination",
|
||||||
"ErrorResponse",
|
"ErrorResponse",
|
||||||
|
"OffsetPagination",
|
||||||
"Pagination",
|
"Pagination",
|
||||||
"PaginatedResponse",
|
"PaginatedResponse",
|
||||||
|
"PydanticBase",
|
||||||
"Response",
|
"Response",
|
||||||
"ResponseStatus",
|
"ResponseStatus",
|
||||||
]
|
]
|
||||||
@@ -49,6 +52,7 @@ class ApiError(PydanticBase):
|
|||||||
msg: str
|
msg: str
|
||||||
desc: str
|
desc: str
|
||||||
err_code: str
|
err_code: str
|
||||||
|
data: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
class BaseResponse(PydanticBase):
|
class BaseResponse(PydanticBase):
|
||||||
@@ -69,7 +73,9 @@ class Response(BaseResponse, Generic[DataT]):
|
|||||||
"""Generic API response with data payload.
|
"""Generic API response with data payload.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
Response[UserRead](data=user, message="User retrieved")
|
Response[UserRead](data=user, message="User retrieved")
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data: DataT | None = None
|
data: DataT | None = None
|
||||||
@@ -83,11 +89,11 @@ class ErrorResponse(BaseResponse):
|
|||||||
|
|
||||||
status: ResponseStatus = ResponseStatus.FAIL
|
status: ResponseStatus = ResponseStatus.FAIL
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
data: None = None
|
data: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
class Pagination(PydanticBase):
|
class OffsetPagination(PydanticBase):
|
||||||
"""Pagination metadata for list responses.
|
"""Pagination metadata for offset-based list responses.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
total_count: Total number of items across all pages
|
total_count: Total number of items across all pages
|
||||||
@@ -102,15 +108,29 @@ class Pagination(PydanticBase):
|
|||||||
has_more: bool
|
has_more: bool
|
||||||
|
|
||||||
|
|
||||||
class PaginatedResponse(BaseResponse, Generic[DataT]):
|
# Backward-compatible - will be removed in v2.0
|
||||||
"""Paginated API response for list endpoints.
|
Pagination = OffsetPagination
|
||||||
|
|
||||||
Example:
|
|
||||||
PaginatedResponse[UserRead](
|
class CursorPagination(PydanticBase):
|
||||||
data=users,
|
"""Pagination metadata for cursor-based list responses.
|
||||||
pagination=Pagination(total_count=100, items_per_page=10, page=1, has_more=True)
|
|
||||||
)
|
Attributes:
|
||||||
|
next_cursor: Encoded cursor for the next page, or None on the last page.
|
||||||
|
prev_cursor: Encoded cursor for the previous page, or None on the first page.
|
||||||
|
items_per_page: Number of items requested per page.
|
||||||
|
has_more: Whether there is at least one more page after this one.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
next_cursor: str | None
|
||||||
|
prev_cursor: str | None = None
|
||||||
|
items_per_page: int
|
||||||
|
has_more: bool
|
||||||
|
|
||||||
|
|
||||||
|
class PaginatedResponse(BaseResponse, Generic[DataT]):
|
||||||
|
"""Paginated API response for list endpoints."""
|
||||||
|
|
||||||
data: list[DataT]
|
data: list[DataT]
|
||||||
pagination: Pagination
|
pagination: OffsetPagination | CursorPagination
|
||||||
|
filter_attributes: dict[str, list[Any]] | None = None
|
||||||
|
|||||||
@@ -1,27 +1,36 @@
|
|||||||
"""Shared pytest fixtures for fastapi-utils tests."""
|
"""Shared pytest fixtures for fastapi-utils tests."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import ForeignKey, String
|
import datetime
|
||||||
|
import decimal
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
Column,
|
||||||
|
Date,
|
||||||
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
|
Integer,
|
||||||
|
Numeric,
|
||||||
|
String,
|
||||||
|
Table,
|
||||||
|
Uuid,
|
||||||
|
)
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from fastapi_toolsets.crud import CrudFactory
|
from fastapi_toolsets.crud import CrudFactory
|
||||||
|
from fastapi_toolsets.schemas import PydanticBase
|
||||||
|
|
||||||
# PostgreSQL connection URL from environment or default for local development
|
DATABASE_URL = os.getenv(
|
||||||
DATABASE_URL = os.getenv("DATABASE_URL") or os.getenv(
|
key="DATABASE_URL",
|
||||||
"TEST_DATABASE_URL",
|
default="postgresql+asyncpg://postgres:postgres@localhost:5432/postgres",
|
||||||
"postgresql+asyncpg://postgres:postgres@localhost:5432/fastapi_toolsets_test",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Test Models
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
class Base(DeclarativeBase):
|
||||||
"""Base class for test models."""
|
"""Base class for test models."""
|
||||||
|
|
||||||
@@ -33,7 +42,7 @@ class Role(Base):
|
|||||||
|
|
||||||
__tablename__ = "roles"
|
__tablename__ = "roles"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
name: Mapped[str] = mapped_column(String(50), unique=True)
|
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
|
|
||||||
users: Mapped[list["User"]] = relationship(back_populates="role")
|
users: Mapped[list["User"]] = relationship(back_populates="role")
|
||||||
@@ -44,36 +53,91 @@ class User(Base):
|
|||||||
|
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
username: Mapped[str] = mapped_column(String(50), unique=True)
|
username: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
email: Mapped[str] = mapped_column(String(100), unique=True)
|
email: Mapped[str] = mapped_column(String(100), unique=True)
|
||||||
is_active: Mapped[bool] = mapped_column(default=True)
|
is_active: Mapped[bool] = mapped_column(default=True)
|
||||||
role_id: Mapped[int | None] = mapped_column(ForeignKey("roles.id"), nullable=True)
|
role_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
|
ForeignKey("roles.id"), nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
role: Mapped[Role | None] = relationship(back_populates="users")
|
role: Mapped[Role | None] = relationship(back_populates="users")
|
||||||
|
|
||||||
|
|
||||||
|
class Tag(Base):
|
||||||
|
"""Test tag model."""
|
||||||
|
|
||||||
|
__tablename__ = "tags"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
|
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
|
|
||||||
|
|
||||||
|
post_tags = Table(
|
||||||
|
"post_tags",
|
||||||
|
Base.metadata,
|
||||||
|
Column(
|
||||||
|
"post_id", Uuid, ForeignKey("posts.id", ondelete="CASCADE"), primary_key=True
|
||||||
|
),
|
||||||
|
Column("tag_id", Uuid, ForeignKey("tags.id", ondelete="CASCADE"), primary_key=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class IntRole(Base):
|
||||||
|
"""Test role model with auto-increment integer PK."""
|
||||||
|
|
||||||
|
__tablename__ = "int_roles"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
|
|
||||||
|
|
||||||
|
class Event(Base):
|
||||||
|
"""Test model with DateTime and Date cursor columns."""
|
||||||
|
|
||||||
|
__tablename__ = "events"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
|
name: Mapped[str] = mapped_column(String(100))
|
||||||
|
occurred_at: Mapped[datetime.datetime] = mapped_column(DateTime)
|
||||||
|
scheduled_date: Mapped[datetime.date] = mapped_column(Date)
|
||||||
|
|
||||||
|
|
||||||
|
class Product(Base):
|
||||||
|
"""Test model with Numeric cursor column."""
|
||||||
|
|
||||||
|
__tablename__ = "products"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
|
name: Mapped[str] = mapped_column(String(100))
|
||||||
|
price: Mapped[decimal.Decimal] = mapped_column(Numeric(10, 2))
|
||||||
|
|
||||||
|
|
||||||
class Post(Base):
|
class Post(Base):
|
||||||
"""Test post model."""
|
"""Test post model."""
|
||||||
|
|
||||||
__tablename__ = "posts"
|
__tablename__ = "posts"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
title: Mapped[str] = mapped_column(String(200))
|
title: Mapped[str] = mapped_column(String(200))
|
||||||
content: Mapped[str] = mapped_column(String(1000), default="")
|
content: Mapped[str] = mapped_column(String(1000), default="")
|
||||||
is_published: Mapped[bool] = mapped_column(default=False)
|
is_published: Mapped[bool] = mapped_column(default=False)
|
||||||
author_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
|
author_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
|
||||||
|
|
||||||
|
tags: Mapped[list[Tag]] = relationship(secondary=post_tags)
|
||||||
# =============================================================================
|
|
||||||
# Test Schemas
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class RoleCreate(BaseModel):
|
class RoleCreate(BaseModel):
|
||||||
"""Schema for creating a role."""
|
"""Schema for creating a role."""
|
||||||
|
|
||||||
id: int | None = None
|
id: uuid.UUID | None = None
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class RoleRead(PydanticBase):
|
||||||
|
"""Schema for reading a role."""
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
@@ -86,11 +150,18 @@ class RoleUpdate(BaseModel):
|
|||||||
class UserCreate(BaseModel):
|
class UserCreate(BaseModel):
|
||||||
"""Schema for creating a user."""
|
"""Schema for creating a user."""
|
||||||
|
|
||||||
id: int | None = None
|
id: uuid.UUID | None = None
|
||||||
username: str
|
username: str
|
||||||
email: str
|
email: str
|
||||||
is_active: bool = True
|
is_active: bool = True
|
||||||
role_id: int | None = None
|
role_id: uuid.UUID | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class UserRead(PydanticBase):
|
||||||
|
"""Schema for reading a user (subset of fields — no email)."""
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
|
username: str
|
||||||
|
|
||||||
|
|
||||||
class UserUpdate(BaseModel):
|
class UserUpdate(BaseModel):
|
||||||
@@ -99,17 +170,24 @@ class UserUpdate(BaseModel):
|
|||||||
username: str | None = None
|
username: str | None = None
|
||||||
email: str | None = None
|
email: str | None = None
|
||||||
is_active: bool | None = None
|
is_active: bool | None = None
|
||||||
role_id: int | None = None
|
role_id: uuid.UUID | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TagCreate(BaseModel):
|
||||||
|
"""Schema for creating a tag."""
|
||||||
|
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
class PostCreate(BaseModel):
|
class PostCreate(BaseModel):
|
||||||
"""Schema for creating a post."""
|
"""Schema for creating a post."""
|
||||||
|
|
||||||
id: int | None = None
|
id: uuid.UUID | None = None
|
||||||
title: str
|
title: str
|
||||||
content: str = ""
|
content: str = ""
|
||||||
is_published: bool = False
|
is_published: bool = False
|
||||||
author_id: int
|
author_id: uuid.UUID
|
||||||
|
|
||||||
|
|
||||||
class PostUpdate(BaseModel):
|
class PostUpdate(BaseModel):
|
||||||
@@ -120,18 +198,60 @@ class PostUpdate(BaseModel):
|
|||||||
is_published: bool | None = None
|
is_published: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
class PostM2MCreate(BaseModel):
|
||||||
# CRUD Classes
|
"""Schema for creating a post with M2M tag IDs."""
|
||||||
# =============================================================================
|
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
title: str
|
||||||
|
content: str = ""
|
||||||
|
is_published: bool = False
|
||||||
|
author_id: uuid.UUID
|
||||||
|
tag_ids: list[uuid.UUID] = []
|
||||||
|
|
||||||
|
|
||||||
|
class PostM2MUpdate(BaseModel):
|
||||||
|
"""Schema for updating a post with M2M tag IDs."""
|
||||||
|
|
||||||
|
title: str | None = None
|
||||||
|
content: str | None = None
|
||||||
|
is_published: bool | None = None
|
||||||
|
tag_ids: list[uuid.UUID] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class IntRoleCreate(BaseModel):
|
||||||
|
"""Schema for creating an IntRole."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class EventCreate(BaseModel):
|
||||||
|
"""Schema for creating an Event."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
occurred_at: datetime.datetime
|
||||||
|
scheduled_date: datetime.date
|
||||||
|
|
||||||
|
|
||||||
|
class ProductCreate(BaseModel):
|
||||||
|
"""Schema for creating a Product."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
price: decimal.Decimal
|
||||||
|
|
||||||
|
|
||||||
RoleCrud = CrudFactory(Role)
|
RoleCrud = CrudFactory(Role)
|
||||||
|
RoleCursorCrud = CrudFactory(Role, cursor_column=Role.id)
|
||||||
|
IntRoleCursorCrud = CrudFactory(IntRole, cursor_column=IntRole.id)
|
||||||
UserCrud = CrudFactory(User)
|
UserCrud = CrudFactory(User)
|
||||||
|
UserCursorCrud = CrudFactory(User, cursor_column=User.id)
|
||||||
PostCrud = CrudFactory(Post)
|
PostCrud = CrudFactory(Post)
|
||||||
|
TagCrud = CrudFactory(Tag)
|
||||||
|
PostM2MCrud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
|
||||||
# =============================================================================
|
EventCrud = CrudFactory(Event)
|
||||||
# Fixtures
|
EventDateTimeCursorCrud = CrudFactory(Event, cursor_column=Event.occurred_at)
|
||||||
# =============================================================================
|
EventDateCursorCrud = CrudFactory(Event, cursor_column=Event.scheduled_date)
|
||||||
|
ProductCrud = CrudFactory(Product)
|
||||||
|
ProductNumericCursorCrud = CrudFactory(Product, cursor_column=Product.price)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -195,5 +315,5 @@ def sample_post_data() -> PostCreate:
|
|||||||
title="Test Post",
|
title="Test Post",
|
||||||
content="Test content",
|
content="Test content",
|
||||||
is_published=True,
|
is_published=True,
|
||||||
author_id=1,
|
author_id=uuid.uuid4(),
|
||||||
)
|
)
|
||||||
|
|||||||
543
tests/test_cli.py
Normal file
543
tests/test_cli.py
Normal file
@@ -0,0 +1,543 @@
|
|||||||
|
"""Tests for fastapi_toolsets.cli module."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli.config import (
|
||||||
|
get_config_value,
|
||||||
|
get_custom_cli,
|
||||||
|
get_db_context,
|
||||||
|
get_fixtures_registry,
|
||||||
|
import_from_string,
|
||||||
|
)
|
||||||
|
from fastapi_toolsets.cli.pyproject import find_pyproject, load_pyproject
|
||||||
|
from fastapi_toolsets.cli.utils import async_command
|
||||||
|
from fastapi_toolsets.fixtures import FixtureRegistry
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
|
||||||
|
|
||||||
|
class TestPyproject:
|
||||||
|
"""Tests for pyproject.toml discovery and loading."""
|
||||||
|
|
||||||
|
def test_find_pyproject_in_current_dir(self, tmp_path, monkeypatch):
|
||||||
|
"""Finds pyproject.toml in current directory."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[project]\nname = 'test'\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = find_pyproject()
|
||||||
|
assert result == pyproject
|
||||||
|
|
||||||
|
def test_find_pyproject_in_parent_dir(self, tmp_path, monkeypatch):
|
||||||
|
"""Finds pyproject.toml in parent directory."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[project]\nname = 'test'\n")
|
||||||
|
subdir = tmp_path / "src" / "app"
|
||||||
|
subdir.mkdir(parents=True)
|
||||||
|
monkeypatch.chdir(subdir)
|
||||||
|
|
||||||
|
result = find_pyproject()
|
||||||
|
assert result == pyproject
|
||||||
|
|
||||||
|
def test_find_pyproject_not_found(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns None when no pyproject.toml exists."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
result = find_pyproject()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_load_pyproject_returns_tool_config(self, tmp_path, monkeypatch):
|
||||||
|
"""load_pyproject returns the [tool.fastapi-toolsets] section."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
'[tool.fastapi-toolsets]\nfixtures = "app.fixtures:registry"\n'
|
||||||
|
)
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = load_pyproject()
|
||||||
|
assert result == {"fixtures": "app.fixtures:registry"}
|
||||||
|
|
||||||
|
def test_load_pyproject_empty_when_no_file(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns empty dict when no pyproject.toml exists."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
result = load_pyproject()
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_load_pyproject_empty_when_no_tool_section(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns empty dict when no [tool.fastapi-toolsets] section."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[project]\nname = 'test'\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = load_pyproject()
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_load_pyproject_invalid_toml(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns empty dict when pyproject.toml is invalid."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("invalid toml {{{")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = load_pyproject()
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestImportFromString:
|
||||||
|
"""Tests for import_from_string function."""
|
||||||
|
|
||||||
|
def test_import_valid_path(self):
|
||||||
|
"""Import valid module:attribute path."""
|
||||||
|
result = import_from_string("fastapi_toolsets.fixtures:FixtureRegistry")
|
||||||
|
assert result is FixtureRegistry
|
||||||
|
|
||||||
|
def test_import_without_colon_raises_error(self):
|
||||||
|
"""Import path without colon raises error."""
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
import_from_string("fastapi_toolsets.fixtures.FixtureRegistry")
|
||||||
|
assert "Expected format: 'module:attribute'" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_import_nonexistent_module_raises_error(self):
|
||||||
|
"""Import nonexistent module raises error."""
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
import_from_string("nonexistent.module:something")
|
||||||
|
assert "Cannot import module" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_import_nonexistent_attribute_raises_error(self):
|
||||||
|
"""Import nonexistent attribute raises error."""
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
import_from_string("fastapi_toolsets.fixtures:NonexistentClass")
|
||||||
|
assert "has no attribute" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetConfigValue:
|
||||||
|
"""Tests for get_config_value function."""
|
||||||
|
|
||||||
|
def test_get_existing_value(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns value when key exists."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text('[tool.fastapi-toolsets]\nfixtures = "app:registry"\n')
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = get_config_value("fixtures")
|
||||||
|
assert result == "app:registry"
|
||||||
|
|
||||||
|
def test_get_missing_value_returns_none(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns None when key is missing and not required."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = get_config_value("fixtures")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_get_missing_value_required_raises_error(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when key is missing and required."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_config_value("fixtures", required=True)
|
||||||
|
assert "No 'fixtures' configured" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetFixturesRegistry:
|
||||||
|
"""Tests for get_fixtures_registry function."""
|
||||||
|
|
||||||
|
def test_raises_when_not_configured(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when fixtures not configured."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_fixtures_registry()
|
||||||
|
assert "No 'fixtures' configured" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_raises_when_not_registry_instance(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when imported object is not a FixtureRegistry."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
'[tool.fastapi-toolsets]\nfixtures = "my_fixtures:registry"\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
fixtures_file = tmp_path / "my_fixtures.py"
|
||||||
|
fixtures_file.write_text("registry = 'not a registry'\n")
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
try:
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_fixtures_registry()
|
||||||
|
assert "must be a FixtureRegistry instance" in str(exc_info.value)
|
||||||
|
finally:
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
if "my_fixtures" in sys.modules:
|
||||||
|
del sys.modules["my_fixtures"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetDbContext:
|
||||||
|
"""Tests for get_db_context function."""
|
||||||
|
|
||||||
|
def test_raises_when_not_configured(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when db_context not configured."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_db_context()
|
||||||
|
assert "No 'db_context' configured" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCustomCli:
|
||||||
|
"""Tests for get_custom_cli function."""
|
||||||
|
|
||||||
|
def test_returns_none_when_not_configured(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns None when custom_cli not configured."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = get_custom_cli()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_raises_when_not_typer_instance(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when imported object is not a Typer instance."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text('[tool.fastapi-toolsets]\ncustom_cli = "my_cli:cli"\n')
|
||||||
|
|
||||||
|
cli_file = tmp_path / "my_cli.py"
|
||||||
|
cli_file.write_text("cli = 'not a typer'\n")
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
try:
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_custom_cli()
|
||||||
|
assert "must be a Typer instance" in str(exc_info.value)
|
||||||
|
finally:
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
if "my_cli" in sys.modules:
|
||||||
|
del sys.modules["my_cli"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestCliApp:
|
||||||
|
"""Tests for CLI application."""
|
||||||
|
|
||||||
|
def test_cli_help(self, tmp_path, monkeypatch):
|
||||||
|
"""CLI shows help without fixtures."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Need to reload the module to pick up new cwd
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
result = runner.invoke(app.cli, ["--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "CLI utilities for FastAPI projects" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
class TestFixturesCli:
|
||||||
|
"""Tests for fixtures CLI commands."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def cli_env(self, tmp_path, monkeypatch):
|
||||||
|
"""Set up CLI environment with fixtures config."""
|
||||||
|
# Create pyproject.toml
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
"[tool.fastapi-toolsets]\n"
|
||||||
|
'fixtures = "fixtures:registry"\n'
|
||||||
|
'db_context = "db:get_session"\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create fixtures module
|
||||||
|
fixtures_file = tmp_path / "fixtures.py"
|
||||||
|
fixtures_file.write_text(
|
||||||
|
"from fastapi_toolsets.fixtures import FixtureRegistry, Context\n"
|
||||||
|
"\n"
|
||||||
|
"registry = FixtureRegistry()\n"
|
||||||
|
"\n"
|
||||||
|
"@registry.register(contexts=[Context.BASE])\n"
|
||||||
|
"def roles():\n"
|
||||||
|
' return [{"id": 1, "name": "admin"}, {"id": 2, "name": "user"}]\n'
|
||||||
|
"\n"
|
||||||
|
'@registry.register(depends_on=["roles"], contexts=[Context.TESTING])\n'
|
||||||
|
"def users():\n"
|
||||||
|
' return [{"id": 1, "name": "alice", "role_id": 1}]\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create db module
|
||||||
|
db_file = tmp_path / "db.py"
|
||||||
|
db_file.write_text(
|
||||||
|
"from contextlib import asynccontextmanager\n"
|
||||||
|
"\n"
|
||||||
|
"@asynccontextmanager\n"
|
||||||
|
"async def get_session():\n"
|
||||||
|
" yield None\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Add tmp_path to sys.path for imports
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
# Reload the CLI module to pick up new config
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
yield tmp_path, app.cli
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
|
||||||
|
def test_fixtures_list(self, cli_env):
|
||||||
|
"""fixtures list shows registered fixtures."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(cli, ["fixtures", "list"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "roles" in result.output
|
||||||
|
assert "users" in result.output
|
||||||
|
assert "Total: 2 fixture(s)" in result.output
|
||||||
|
|
||||||
|
def test_fixtures_list_with_context(self, cli_env):
|
||||||
|
"""fixtures list --context filters by context."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(cli, ["fixtures", "list", "--context", "base"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "roles" in result.output
|
||||||
|
assert "users" not in result.output
|
||||||
|
assert "Total: 1 fixture(s)" in result.output
|
||||||
|
|
||||||
|
def test_fixtures_load_dry_run(self, cli_env):
|
||||||
|
"""fixtures load --dry-run shows what would be loaded."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(cli, ["fixtures", "load", "base", "--dry-run"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Fixtures to load" in result.output
|
||||||
|
assert "roles" in result.output
|
||||||
|
assert "[Dry run - no changes made]" in result.output
|
||||||
|
|
||||||
|
def test_fixtures_load_invalid_strategy(self, cli_env):
|
||||||
|
"""fixtures load with invalid strategy shows error."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(
|
||||||
|
cli, ["fixtures", "load", "base", "--strategy", "invalid"]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code != 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestCliWithoutFixturesConfig:
|
||||||
|
"""Tests for CLI when fixtures is not configured."""
|
||||||
|
|
||||||
|
def test_no_fixtures_command(self, tmp_path, monkeypatch):
|
||||||
|
"""fixtures command is not available when not configured."""
|
||||||
|
# Create pyproject.toml without fixtures
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text('[project]\nname = "test"\n')
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Reload the CLI module
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
result = runner.invoke(app.cli, ["--help"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "fixtures" not in result.output
|
||||||
|
|
||||||
|
|
||||||
|
class TestCustomCliConfig:
|
||||||
|
"""Tests for custom CLI configuration."""
|
||||||
|
|
||||||
|
def test_cli_with_custom_cli(self, tmp_path, monkeypatch):
|
||||||
|
"""CLI uses custom Typer instance when configured."""
|
||||||
|
import typer
|
||||||
|
|
||||||
|
# Create pyproject.toml with custom_cli config
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text('[tool.fastapi-toolsets]\ncustom_cli = "my_cli:cli"\n')
|
||||||
|
|
||||||
|
# Create custom CLI module with its own Typer and commands
|
||||||
|
cli_file = tmp_path / "my_cli.py"
|
||||||
|
cli_file.write_text(
|
||||||
|
"import typer\n"
|
||||||
|
"\n"
|
||||||
|
"cli = typer.Typer(name='my-app', help='My custom CLI')\n"
|
||||||
|
"\n"
|
||||||
|
"@cli.command()\n"
|
||||||
|
"def hello():\n"
|
||||||
|
' print("Hello from custom CLI!")\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Add tmp_path to sys.path for imports
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
# Remove my_cli from sys.modules if it was previously loaded
|
||||||
|
if "my_cli" in sys.modules:
|
||||||
|
del sys.modules["my_cli"]
|
||||||
|
|
||||||
|
# Reload the CLI module to pick up new config
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Verify custom CLI is used
|
||||||
|
assert isinstance(app.cli, typer.Typer)
|
||||||
|
|
||||||
|
result = runner.invoke(app.cli, ["--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "My custom CLI" in result.output
|
||||||
|
assert "hello" in result.output
|
||||||
|
|
||||||
|
result = runner.invoke(app.cli, ["hello"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Hello from custom CLI!" in result.output
|
||||||
|
finally:
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
if "my_cli" in sys.modules:
|
||||||
|
del sys.modules["my_cli"]
|
||||||
|
|
||||||
|
def test_custom_cli_with_fixtures(self, tmp_path, monkeypatch):
|
||||||
|
"""Custom CLI gets fixtures command added when configured."""
|
||||||
|
# Create pyproject.toml with both custom_cli and fixtures
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
"[tool.fastapi-toolsets]\n"
|
||||||
|
'custom_cli = "my_cli:cli"\n'
|
||||||
|
'fixtures = "fixtures:registry"\n'
|
||||||
|
'db_context = "db:get_session"\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create custom CLI module
|
||||||
|
cli_file = tmp_path / "my_cli.py"
|
||||||
|
cli_file.write_text(
|
||||||
|
"import typer\n"
|
||||||
|
"\n"
|
||||||
|
"cli = typer.Typer(name='my-app', help='My custom CLI')\n"
|
||||||
|
"\n"
|
||||||
|
"@cli.command()\n"
|
||||||
|
"def hello():\n"
|
||||||
|
' print("Hello!")\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create fixtures module
|
||||||
|
fixtures_file = tmp_path / "fixtures.py"
|
||||||
|
fixtures_file.write_text(
|
||||||
|
"from fastapi_toolsets.fixtures import FixtureRegistry\n"
|
||||||
|
"\n"
|
||||||
|
"registry = FixtureRegistry()\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create db module
|
||||||
|
db_file = tmp_path / "db.py"
|
||||||
|
db_file.write_text(
|
||||||
|
"from contextlib import asynccontextmanager\n"
|
||||||
|
"\n"
|
||||||
|
"@asynccontextmanager\n"
|
||||||
|
"async def get_session():\n"
|
||||||
|
" yield None\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
for mod in ["my_cli", "fixtures", "db"]:
|
||||||
|
if mod in sys.modules:
|
||||||
|
del sys.modules[mod]
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = runner.invoke(app.cli, ["--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
# Should have both custom command and fixtures
|
||||||
|
assert "hello" in result.output
|
||||||
|
assert "fixtures" in result.output
|
||||||
|
finally:
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
for mod in ["my_cli", "fixtures", "db"]:
|
||||||
|
if mod in sys.modules:
|
||||||
|
del sys.modules[mod]
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncCommand:
|
||||||
|
"""Tests for async_command decorator."""
|
||||||
|
|
||||||
|
def test_async_command_runs_coroutine(self):
|
||||||
|
"""async_command runs async function synchronously."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def async_func(value: int) -> int:
|
||||||
|
return value * 2
|
||||||
|
|
||||||
|
result = async_func(21)
|
||||||
|
assert result == 42
|
||||||
|
|
||||||
|
def test_async_command_preserves_signature(self):
|
||||||
|
"""async_command preserves function signature."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def async_func(name: str, count: int = 1) -> str:
|
||||||
|
return f"{name} x {count}"
|
||||||
|
|
||||||
|
result = async_func("test", count=3)
|
||||||
|
assert result == "test x 3"
|
||||||
|
|
||||||
|
def test_async_command_preserves_docstring(self):
|
||||||
|
"""async_command preserves function docstring."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def async_func() -> None:
|
||||||
|
"""This is a docstring."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert async_func.__doc__ == """This is a docstring."""
|
||||||
|
|
||||||
|
def test_async_command_preserves_name(self):
|
||||||
|
"""async_command preserves function name."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def my_async_function() -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert my_async_function.__name__ == "my_async_function"
|
||||||
1704
tests/test_crud.py
1704
tests/test_crud.py
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,20 @@
|
|||||||
"""Tests for CRUD search functionality."""
|
"""Tests for CRUD search functionality."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.sql.elements import ColumnElement, UnaryExpression
|
||||||
|
|
||||||
from fastapi_toolsets.crud import SearchConfig, get_searchable_fields
|
from fastapi_toolsets.crud import (
|
||||||
|
CrudFactory,
|
||||||
|
InvalidFacetFilterError,
|
||||||
|
SearchConfig,
|
||||||
|
get_searchable_fields,
|
||||||
|
)
|
||||||
|
from fastapi_toolsets.exceptions import InvalidOrderFieldError
|
||||||
|
from fastapi_toolsets.schemas import OffsetPagination
|
||||||
|
|
||||||
from .conftest import (
|
from .conftest import (
|
||||||
Role,
|
Role,
|
||||||
@@ -37,7 +48,8 @@ class TestPaginateSearch:
|
|||||||
search_fields=[User.username],
|
search_fields=[User.username],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 2
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_multiple_columns(self, db_session: AsyncSession):
|
async def test_search_multiple_columns(self, db_session: AsyncSession):
|
||||||
@@ -55,7 +67,8 @@ class TestPaginateSearch:
|
|||||||
search_fields=[User.username, User.email],
|
search_fields=[User.username, User.email],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 2
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_relationship_depth1(self, db_session: AsyncSession):
|
async def test_search_relationship_depth1(self, db_session: AsyncSession):
|
||||||
@@ -82,7 +95,8 @@ class TestPaginateSearch:
|
|||||||
search_fields=[(User.role, Role.name)],
|
search_fields=[(User.role, Role.name)],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 2
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_mixed_direct_and_relation(self, db_session: AsyncSession):
|
async def test_search_mixed_direct_and_relation(self, db_session: AsyncSession):
|
||||||
@@ -100,7 +114,8 @@ class TestPaginateSearch:
|
|||||||
search_fields=[User.username, (User.role, Role.name)],
|
search_fields=[User.username, (User.role, Role.name)],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 1
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_case_insensitive(self, db_session: AsyncSession):
|
async def test_search_case_insensitive(self, db_session: AsyncSession):
|
||||||
@@ -115,7 +130,8 @@ class TestPaginateSearch:
|
|||||||
search_fields=[User.username],
|
search_fields=[User.username],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 1
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_case_sensitive(self, db_session: AsyncSession):
|
async def test_search_case_sensitive(self, db_session: AsyncSession):
|
||||||
@@ -130,7 +146,8 @@ class TestPaginateSearch:
|
|||||||
search=SearchConfig(query="johndoe", case_sensitive=True),
|
search=SearchConfig(query="johndoe", case_sensitive=True),
|
||||||
search_fields=[User.username],
|
search_fields=[User.username],
|
||||||
)
|
)
|
||||||
assert result["pagination"]["total_count"] == 0
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 0
|
||||||
|
|
||||||
# Should find (case match)
|
# Should find (case match)
|
||||||
result = await UserCrud.paginate(
|
result = await UserCrud.paginate(
|
||||||
@@ -138,7 +155,8 @@ class TestPaginateSearch:
|
|||||||
search=SearchConfig(query="JohnDoe", case_sensitive=True),
|
search=SearchConfig(query="JohnDoe", case_sensitive=True),
|
||||||
search_fields=[User.username],
|
search_fields=[User.username],
|
||||||
)
|
)
|
||||||
assert result["pagination"]["total_count"] == 1
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_empty_query(self, db_session: AsyncSession):
|
async def test_search_empty_query(self, db_session: AsyncSession):
|
||||||
@@ -151,10 +169,12 @@ class TestPaginateSearch:
|
|||||||
)
|
)
|
||||||
|
|
||||||
result = await UserCrud.paginate(db_session, search="")
|
result = await UserCrud.paginate(db_session, search="")
|
||||||
assert result["pagination"]["total_count"] == 2
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
result = await UserCrud.paginate(db_session, search=None)
|
result = await UserCrud.paginate(db_session, search=None)
|
||||||
assert result["pagination"]["total_count"] == 2
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_with_existing_filters(self, db_session: AsyncSession):
|
async def test_search_with_existing_filters(self, db_session: AsyncSession):
|
||||||
@@ -175,8 +195,9 @@ class TestPaginateSearch:
|
|||||||
search_fields=[User.username],
|
search_fields=[User.username],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 1
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
assert result["data"][0].username == "active_john"
|
assert result.pagination.total_count == 1
|
||||||
|
assert result.data[0].username == "active_john"
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_auto_detect_fields(self, db_session: AsyncSession):
|
async def test_search_auto_detect_fields(self, db_session: AsyncSession):
|
||||||
@@ -187,7 +208,8 @@ class TestPaginateSearch:
|
|||||||
|
|
||||||
result = await UserCrud.paginate(db_session, search="findme")
|
result = await UserCrud.paginate(db_session, search="findme")
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 1
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_no_results(self, db_session: AsyncSession):
|
async def test_search_no_results(self, db_session: AsyncSession):
|
||||||
@@ -202,8 +224,9 @@ class TestPaginateSearch:
|
|||||||
search_fields=[User.username],
|
search_fields=[User.username],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 0
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
assert result["data"] == []
|
assert result.pagination.total_count == 0
|
||||||
|
assert result.data == []
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_with_pagination(self, db_session: AsyncSession):
|
async def test_search_with_pagination(self, db_session: AsyncSession):
|
||||||
@@ -222,9 +245,10 @@ class TestPaginateSearch:
|
|||||||
items_per_page=5,
|
items_per_page=5,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 15
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
assert len(result["data"]) == 5
|
assert result.pagination.total_count == 15
|
||||||
assert result["pagination"]["has_more"] is True
|
assert len(result.data) == 5
|
||||||
|
assert result.pagination.has_more is True
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_null_relationship(self, db_session: AsyncSession):
|
async def test_search_null_relationship(self, db_session: AsyncSession):
|
||||||
@@ -246,7 +270,8 @@ class TestPaginateSearch:
|
|||||||
search_fields=[User.username],
|
search_fields=[User.username],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 2
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_with_order_by(self, db_session: AsyncSession):
|
async def test_search_with_order_by(self, db_session: AsyncSession):
|
||||||
@@ -268,14 +293,63 @@ class TestPaginateSearch:
|
|||||||
order_by=User.username,
|
order_by=User.username,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 3
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
usernames = [u.username for u in result["data"]]
|
assert result.pagination.total_count == 3
|
||||||
|
usernames = [u.username for u in result.data]
|
||||||
assert usernames == ["alice", "bob", "charlie"]
|
assert usernames == ["alice", "bob", "charlie"]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_non_string_column(self, db_session: AsyncSession):
|
||||||
|
"""Search on non-string columns (e.g., UUID) works via cast."""
|
||||||
|
user_id = uuid.UUID("12345678-1234-5678-1234-567812345678")
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(id=user_id, username="john", email="john@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="jane", email="jane@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search by UUID (partial match)
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search="12345678",
|
||||||
|
search_fields=[User.id, User.username],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
assert result.data[0].id == user_id
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildSearchFilters:
|
||||||
|
"""Unit tests for build_search_filters."""
|
||||||
|
|
||||||
|
def test_deduplicates_relationship_join(self):
|
||||||
|
"""Two tuple fields sharing the same relationship do not add the join twice."""
|
||||||
|
from fastapi_toolsets.crud.search import build_search_filters
|
||||||
|
|
||||||
|
# Both fields traverse User.role — the second must not re-add the join.
|
||||||
|
filters, joins = build_search_filters(
|
||||||
|
User,
|
||||||
|
"admin",
|
||||||
|
search_fields=[(User.role, Role.name), (User.role, Role.id)],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(joins) == 1
|
||||||
|
|
||||||
|
|
||||||
class TestSearchConfig:
|
class TestSearchConfig:
|
||||||
"""Tests for SearchConfig options."""
|
"""Tests for SearchConfig options."""
|
||||||
|
|
||||||
|
def test_search_config_empty_query_returns_empty(self):
|
||||||
|
"""SearchConfig with an empty/blank query returns empty filters without hitting the DB."""
|
||||||
|
from fastapi_toolsets.crud.search import build_search_filters
|
||||||
|
|
||||||
|
filters, joins = build_search_filters(User, SearchConfig(query=" "))
|
||||||
|
|
||||||
|
assert filters == []
|
||||||
|
assert joins == []
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_match_mode_all(self, db_session: AsyncSession):
|
async def test_match_mode_all(self, db_session: AsyncSession):
|
||||||
"""match_mode='all' requires all fields to match (AND)."""
|
"""match_mode='all' requires all fields to match (AND)."""
|
||||||
@@ -295,8 +369,9 @@ class TestSearchConfig:
|
|||||||
search_fields=[User.username, User.email],
|
search_fields=[User.username, User.email],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 1
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
assert result["data"][0].username == "john_test"
|
assert result.pagination.total_count == 1
|
||||||
|
assert result.data[0].username == "john_test"
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_config_with_fields(self, db_session: AsyncSession):
|
async def test_search_config_with_fields(self, db_session: AsyncSession):
|
||||||
@@ -310,7 +385,8 @@ class TestSearchConfig:
|
|||||||
search=SearchConfig(query="findme", fields=[User.email]),
|
search=SearchConfig(query="findme", fields=[User.email]),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 1
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
|
||||||
|
|
||||||
class TestNoSearchableFieldsError:
|
class TestNoSearchableFieldsError:
|
||||||
@@ -390,3 +466,695 @@ class TestGetSearchableFields:
|
|||||||
# Role.users is a collection, should not be included
|
# Role.users is a collection, should not be included
|
||||||
field_strs = [str(f) for f in fields]
|
field_strs = [str(f) for f in fields]
|
||||||
assert not any("users" in f for f in field_strs)
|
assert not any("users" in f for f in field_strs)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFacetsNotSet:
|
||||||
|
"""filter_attributes is None when no facet_fields are configured."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_offset_paginate_no_facets(self, db_session: AsyncSession):
|
||||||
|
"""filter_attributes is None when facet_fields not set on factory or call."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.offset_paginate(db_session)
|
||||||
|
|
||||||
|
assert result.filter_attributes is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_cursor_paginate_no_facets(self, db_session: AsyncSession):
|
||||||
|
"""filter_attributes is None for cursor_paginate when facet_fields not set."""
|
||||||
|
UserCursorCrud = CrudFactory(User, cursor_column=User.id)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCursorCrud.cursor_paginate(db_session)
|
||||||
|
|
||||||
|
assert result.filter_attributes is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestFacetsDirectColumn:
|
||||||
|
"""Facets on direct model columns."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_offset_paginate_direct_column(self, db_session: AsyncSession):
|
||||||
|
"""Returns distinct values for a direct column via factory default."""
|
||||||
|
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="b@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserFacetCrud.offset_paginate(db_session)
|
||||||
|
|
||||||
|
assert result.filter_attributes is not None
|
||||||
|
# Distinct usernames, sorted
|
||||||
|
assert result.filter_attributes["username"] == ["alice", "bob"]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_cursor_paginate_direct_column(self, db_session: AsyncSession):
|
||||||
|
"""Returns distinct values for a direct column in cursor_paginate."""
|
||||||
|
UserFacetCursorCrud = CrudFactory(
|
||||||
|
User, cursor_column=User.id, facet_fields=[User.email]
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="b@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserFacetCursorCrud.cursor_paginate(db_session)
|
||||||
|
|
||||||
|
assert result.filter_attributes is not None
|
||||||
|
assert set(result.filter_attributes["email"]) == {"a@test.com", "b@test.com"}
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_multiple_facet_columns(self, db_session: AsyncSession):
|
||||||
|
"""Returns distinct values for multiple columns."""
|
||||||
|
UserFacetCrud = CrudFactory(User, facet_fields=[User.username, User.email])
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="b@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserFacetCrud.offset_paginate(db_session)
|
||||||
|
|
||||||
|
assert result.filter_attributes is not None
|
||||||
|
assert "username" in result.filter_attributes
|
||||||
|
assert "email" in result.filter_attributes
|
||||||
|
assert set(result.filter_attributes["username"]) == {"alice", "bob"}
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_per_call_override(self, db_session: AsyncSession):
|
||||||
|
"""Per-call facet_fields overrides the factory default."""
|
||||||
|
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Override: ask for email instead of username
|
||||||
|
result = await UserFacetCrud.offset_paginate(
|
||||||
|
db_session, facet_fields=[User.email]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.filter_attributes is not None
|
||||||
|
assert "email" in result.filter_attributes
|
||||||
|
assert "username" not in result.filter_attributes
|
||||||
|
|
||||||
|
|
||||||
|
class TestFacetsRespectFilters:
|
||||||
|
"""Facets reflect the active filter conditions."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_facets_respect_base_filters(self, db_session: AsyncSession):
|
||||||
|
"""Facet values are scoped to the applied filters."""
|
||||||
|
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com", is_active=True)
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="b@test.com", is_active=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter to active users only — facets should only see "alice"
|
||||||
|
result = await UserFacetCrud.offset_paginate(
|
||||||
|
db_session,
|
||||||
|
filters=[User.is_active == True], # noqa: E712
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.filter_attributes is not None
|
||||||
|
assert result.filter_attributes["username"] == ["alice"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestFacetsRelationship:
|
||||||
|
"""Facets on relationship columns via tuple syntax."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_relationship_facet(self, db_session: AsyncSession):
|
||||||
|
"""Returns distinct values from a related model column."""
|
||||||
|
UserRelFacetCrud = CrudFactory(User, facet_fields=[(User.role, Role.name)])
|
||||||
|
|
||||||
|
admin = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||||
|
editor = await RoleCrud.create(db_session, RoleCreate(name="editor"))
|
||||||
|
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="alice", email="a@test.com", role_id=admin.id),
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="bob", email="b@test.com", role_id=editor.id),
|
||||||
|
)
|
||||||
|
# User without a role — their role.name should be excluded (None filtered out)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="charlie", email="c@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserRelFacetCrud.offset_paginate(db_session)
|
||||||
|
|
||||||
|
assert result.filter_attributes is not None
|
||||||
|
assert set(result.filter_attributes["name"]) == {"admin", "editor"}
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_relationship_facet_none_excluded(self, db_session: AsyncSession):
|
||||||
|
"""None values (e.g. NULL role) are excluded from facet results."""
|
||||||
|
UserRelFacetCrud = CrudFactory(User, facet_fields=[(User.role, Role.name)])
|
||||||
|
|
||||||
|
# Only user with no role
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="norole", email="n@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserRelFacetCrud.offset_paginate(db_session)
|
||||||
|
|
||||||
|
assert result.filter_attributes is not None
|
||||||
|
assert result.filter_attributes["name"] == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_relationship_facet_deduplicates_join_with_search(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""Facet join is not duplicated when search already added the same relationship join."""
|
||||||
|
# Both search and facet use (User.role, Role.name) — join should not be doubled
|
||||||
|
UserSearchFacetCrud = CrudFactory(
|
||||||
|
User,
|
||||||
|
searchable_fields=[(User.role, Role.name)],
|
||||||
|
facet_fields=[(User.role, Role.name)],
|
||||||
|
)
|
||||||
|
|
||||||
|
admin = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="alice", email="a@test.com", role_id=admin.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserSearchFacetCrud.offset_paginate(
|
||||||
|
db_session, search="admin", search_fields=[(User.role, Role.name)]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.filter_attributes is not None
|
||||||
|
assert result.filter_attributes["name"] == ["admin"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestFilterBy:
|
||||||
|
"""Tests for the filter_by parameter on offset_paginate and cursor_paginate."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_scalar_filter(self, db_session: AsyncSession):
|
||||||
|
"""filter_by with a scalar value produces an equality filter."""
|
||||||
|
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="b@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserFacetCrud.offset_paginate(
|
||||||
|
db_session, filter_by={"username": "alice"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result.data) == 1
|
||||||
|
assert result.data[0].username == "alice"
|
||||||
|
# facet also scoped to the filter
|
||||||
|
assert result.filter_attributes == {"username": ["alice"]}
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_list_filter_produces_in_clause(self, db_session: AsyncSession):
|
||||||
|
"""filter_by with a list value produces an IN filter."""
|
||||||
|
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="b@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="charlie", email="c@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserFacetCrud.offset_paginate(
|
||||||
|
db_session, filter_by={"username": ["alice", "bob"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 2
|
||||||
|
returned_names = {u.username for u in result.data}
|
||||||
|
assert returned_names == {"alice", "bob"}
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_relationship_filter_by(self, db_session: AsyncSession):
|
||||||
|
"""filter_by works with relationship tuple facet fields."""
|
||||||
|
UserRelFacetCrud = CrudFactory(User, facet_fields=[(User.role, Role.name)])
|
||||||
|
|
||||||
|
admin = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||||
|
editor = await RoleCrud.create(db_session, RoleCreate(name="editor"))
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="alice", email="a@test.com", role_id=admin.id),
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="bob", email="b@test.com", role_id=editor.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserRelFacetCrud.offset_paginate(
|
||||||
|
db_session, filter_by={"name": "admin"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
assert result.data[0].username == "alice"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_filter_by_combined_with_filters(self, db_session: AsyncSession):
|
||||||
|
"""filter_by and filters= are combined (AND logic)."""
|
||||||
|
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com", is_active=True)
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="alice2", email="a2@test.com", is_active=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserFacetCrud.offset_paginate(
|
||||||
|
db_session,
|
||||||
|
filters=[User.is_active == True], # noqa: E712
|
||||||
|
filter_by={"username": ["alice", "alice2"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only alice passes both: is_active=True AND username IN [alice, alice2]
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
assert result.data[0].username == "alice"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_invalid_key_raises(self, db_session: AsyncSession):
|
||||||
|
"""filter_by with an undeclared key raises InvalidFacetFilterError."""
|
||||||
|
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||||
|
|
||||||
|
with pytest.raises(InvalidFacetFilterError) as exc_info:
|
||||||
|
await UserFacetCrud.offset_paginate(
|
||||||
|
db_session, filter_by={"nonexistent": "value"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.key == "nonexistent"
|
||||||
|
assert "username" in exc_info.value.valid_keys
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_filter_by_deduplicates_relationship_join(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""Two filter_by keys through the same relationship do not duplicate the join."""
|
||||||
|
# Both (User.role, Role.name) and (User.role, Role.id) traverse User.role —
|
||||||
|
# the second key must not re-add the join (exercises the dedup branch in build_filter_by).
|
||||||
|
UserRoleFacetCrud = CrudFactory(
|
||||||
|
User,
|
||||||
|
facet_fields=[(User.role, Role.name), (User.role, Role.id)],
|
||||||
|
)
|
||||||
|
|
||||||
|
admin = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||||
|
editor = await RoleCrud.create(db_session, RoleCreate(name="editor"))
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="alice", email="a@test.com", role_id=admin.id),
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="bob", email="b@test.com", role_id=editor.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserRoleFacetCrud.offset_paginate(
|
||||||
|
db_session,
|
||||||
|
filter_by={"name": "admin", "id": str(admin.id)},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
assert result.data[0].username == "alice"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_cursor_paginate_filter_by(self, db_session: AsyncSession):
|
||||||
|
"""filter_by works with cursor_paginate."""
|
||||||
|
UserFacetCursorCrud = CrudFactory(
|
||||||
|
User, cursor_column=User.id, facet_fields=[User.username]
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="b@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserFacetCursorCrud.cursor_paginate(
|
||||||
|
db_session, filter_by={"username": "alice"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result.data) == 1
|
||||||
|
assert result.data[0].username == "alice"
|
||||||
|
assert result.filter_attributes == {"username": ["alice"]}
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_basemodel_filter_by_offset_paginate(self, db_session: AsyncSession):
|
||||||
|
"""filter_by accepts a BaseModel instance (model_dump path) in offset_paginate."""
|
||||||
|
from pydantic import BaseModel as PydanticBaseModel
|
||||||
|
|
||||||
|
class UserFilter(PydanticBaseModel):
|
||||||
|
username: str | None = None
|
||||||
|
|
||||||
|
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="b@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserFacetCrud.offset_paginate(
|
||||||
|
db_session, filter_by=UserFilter(username="alice")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
assert result.data[0].username == "alice"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_basemodel_filter_by_cursor_paginate(self, db_session: AsyncSession):
|
||||||
|
"""filter_by accepts a BaseModel instance (model_dump path) in cursor_paginate."""
|
||||||
|
from pydantic import BaseModel as PydanticBaseModel
|
||||||
|
|
||||||
|
class UserFilter(PydanticBaseModel):
|
||||||
|
username: str | None = None
|
||||||
|
|
||||||
|
UserFacetCursorCrud = CrudFactory(
|
||||||
|
User, cursor_column=User.id, facet_fields=[User.username]
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="b@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserFacetCursorCrud.cursor_paginate(
|
||||||
|
db_session, filter_by=UserFilter(username="alice")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result.data) == 1
|
||||||
|
assert result.data[0].username == "alice"
|
||||||
|
|
||||||
|
|
||||||
|
class TestFilterParamsSchema:
|
||||||
|
"""Tests for AsyncCrud.filter_params()."""
|
||||||
|
|
||||||
|
def test_generates_fields_from_facet_fields(self):
|
||||||
|
"""Returned dependency has one keyword param per facet field."""
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
UserFacetCrud = CrudFactory(User, facet_fields=[User.username, User.email])
|
||||||
|
dep = UserFacetCrud.filter_params()
|
||||||
|
|
||||||
|
param_names = set(inspect.signature(dep).parameters)
|
||||||
|
assert param_names == {"username", "email"}
|
||||||
|
|
||||||
|
def test_relationship_facet_uses_column_key(self):
|
||||||
|
"""Relationship tuple uses the terminal column's key."""
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
UserRoleCrud = CrudFactory(User, facet_fields=[(User.role, Role.name)])
|
||||||
|
dep = UserRoleCrud.filter_params()
|
||||||
|
|
||||||
|
param_names = set(inspect.signature(dep).parameters)
|
||||||
|
assert param_names == {"name"}
|
||||||
|
|
||||||
|
def test_raises_when_no_facet_fields(self):
|
||||||
|
"""ValueError raised when no facet_fields are configured or provided."""
|
||||||
|
with pytest.raises(ValueError, match="no facet_fields"):
|
||||||
|
UserCrud.filter_params()
|
||||||
|
|
||||||
|
def test_facet_fields_override(self):
|
||||||
|
"""facet_fields= parameter overrides the class-level default."""
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
UserFacetCrud = CrudFactory(User, facet_fields=[User.username, User.email])
|
||||||
|
dep = UserFacetCrud.filter_params(facet_fields=[User.email])
|
||||||
|
|
||||||
|
param_names = set(inspect.signature(dep).parameters)
|
||||||
|
assert param_names == {"email"}
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_awaiting_dep_returns_dict_with_values(self):
|
||||||
|
"""Awaiting the dependency returns a dict with only the supplied keys."""
|
||||||
|
UserFacetCrud = CrudFactory(User, facet_fields=[User.username, User.email])
|
||||||
|
dep = UserFacetCrud.filter_params()
|
||||||
|
|
||||||
|
result = await dep(username=["alice"])
|
||||||
|
assert result == {"username": ["alice"]}
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_multi_value_list_field(self):
|
||||||
|
"""Multiple values are accepted as a list."""
|
||||||
|
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||||
|
dep = UserFacetCrud.filter_params()
|
||||||
|
|
||||||
|
result = await dep(username=["alice", "bob"])
|
||||||
|
assert result == {"username": ["alice", "bob"]}
|
||||||
|
|
||||||
|
def test_disambiguates_duplicate_column_keys(self):
|
||||||
|
"""Two relationship tuples sharing a terminal column key get prefixed names."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from fastapi_toolsets.crud.search import facet_keys
|
||||||
|
|
||||||
|
col_a = MagicMock()
|
||||||
|
col_a.key = "name"
|
||||||
|
rel_a = MagicMock()
|
||||||
|
rel_a.key = "project"
|
||||||
|
|
||||||
|
col_b = MagicMock()
|
||||||
|
col_b.key = "name"
|
||||||
|
rel_b = MagicMock()
|
||||||
|
rel_b.key = "os"
|
||||||
|
|
||||||
|
keys = facet_keys([(rel_a, col_a), (rel_b, col_b)])
|
||||||
|
assert keys == ["project__name", "os__name"]
|
||||||
|
|
||||||
|
def test_unique_column_keys_kept_plain(self):
|
||||||
|
"""Fields with unique column keys are not prefixed."""
|
||||||
|
from fastapi_toolsets.crud.search import facet_keys
|
||||||
|
|
||||||
|
keys = facet_keys([User.username, User.email])
|
||||||
|
assert keys == ["username", "email"]
|
||||||
|
|
||||||
|
def test_dependency_name_includes_model_name(self):
|
||||||
|
"""Returned dependency is named {Model}FilterParams."""
|
||||||
|
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||||
|
dep = UserFacetCrud.filter_params()
|
||||||
|
|
||||||
|
assert dep.__name__ == "UserFilterParams" # type: ignore[union-attr]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_integration_with_offset_paginate(self, db_session: AsyncSession):
|
||||||
|
"""Dependency result can be passed directly to offset_paginate via filter_by."""
|
||||||
|
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="b@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
dep = UserFacetCrud.filter_params()
|
||||||
|
f = await dep(username=["alice"])
|
||||||
|
result = await UserFacetCrud.offset_paginate(db_session, filter_by=f)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
assert result.data[0].username == "alice"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_dep_result_passed_to_cursor_paginate(self, db_session: AsyncSession):
|
||||||
|
"""Dependency result can be passed directly to cursor_paginate via filter_by."""
|
||||||
|
UserFacetCursorCrud = CrudFactory(
|
||||||
|
User, cursor_column=User.id, facet_fields=[User.username]
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="b@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
dep = UserFacetCursorCrud.filter_params()
|
||||||
|
f = await dep(username=["alice"])
|
||||||
|
result = await UserFacetCursorCrud.cursor_paginate(db_session, filter_by=f)
|
||||||
|
|
||||||
|
assert len(result.data) == 1
|
||||||
|
assert result.data[0].username == "alice"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_all_none_dep_result_passes_no_filter(self, db_session: AsyncSession):
|
||||||
|
"""All-None dependency result results in no filter (returns all rows)."""
|
||||||
|
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="b@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
dep = UserFacetCrud.filter_params()
|
||||||
|
f = await dep() # all fields None
|
||||||
|
result = await UserFacetCrud.offset_paginate(db_session, filter_by=f)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrderParamsSchema:
|
||||||
|
"""Tests for AsyncCrud.order_params()."""
|
||||||
|
|
||||||
|
def test_generates_order_by_and_order_params(self):
|
||||||
|
"""Returned dependency has order_by and order query params."""
|
||||||
|
UserOrderCrud = CrudFactory(User, order_fields=[User.username, User.email])
|
||||||
|
dep = UserOrderCrud.order_params()
|
||||||
|
|
||||||
|
param_names = set(inspect.signature(dep).parameters)
|
||||||
|
assert param_names == {"order_by", "order"}
|
||||||
|
|
||||||
|
def test_dependency_name_includes_model_name(self):
|
||||||
|
"""Dependency function is named after the model."""
|
||||||
|
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
|
||||||
|
dep = UserOrderCrud.order_params()
|
||||||
|
assert getattr(dep, "__name__") == "UserOrderParams"
|
||||||
|
|
||||||
|
def test_raises_when_no_order_fields(self):
|
||||||
|
"""ValueError raised when no order_fields are configured or provided."""
|
||||||
|
with pytest.raises(ValueError, match="no order_fields"):
|
||||||
|
UserCrud.order_params()
|
||||||
|
|
||||||
|
def test_order_fields_override(self):
|
||||||
|
"""order_fields= parameter overrides the class-level default."""
|
||||||
|
UserOrderCrud = CrudFactory(User, order_fields=[User.username, User.email])
|
||||||
|
dep = UserOrderCrud.order_params(order_fields=[User.email])
|
||||||
|
|
||||||
|
param_names = set(inspect.signature(dep).parameters)
|
||||||
|
assert "order_by" in param_names
|
||||||
|
# description should only mention email, not username
|
||||||
|
sig = inspect.signature(dep)
|
||||||
|
description = sig.parameters["order_by"].default.description
|
||||||
|
assert "email" in description
|
||||||
|
assert "username" not in description
|
||||||
|
|
||||||
|
def test_order_by_description_lists_valid_fields(self):
|
||||||
|
"""order_by query param description mentions each allowed field."""
|
||||||
|
UserOrderCrud = CrudFactory(User, order_fields=[User.username, User.email])
|
||||||
|
dep = UserOrderCrud.order_params()
|
||||||
|
|
||||||
|
sig = inspect.signature(dep)
|
||||||
|
description = sig.parameters["order_by"].default.description
|
||||||
|
assert "username" in description
|
||||||
|
assert "email" in description
|
||||||
|
|
||||||
|
def test_default_order_reflected_in_order_default(self):
|
||||||
|
"""default_order is used as the default value for order."""
|
||||||
|
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
|
||||||
|
dep_asc = UserOrderCrud.order_params(default_order="asc")
|
||||||
|
dep_desc = UserOrderCrud.order_params(default_order="desc")
|
||||||
|
|
||||||
|
sig_asc = inspect.signature(dep_asc)
|
||||||
|
sig_desc = inspect.signature(dep_desc)
|
||||||
|
assert sig_asc.parameters["order"].default.default == "asc"
|
||||||
|
assert sig_desc.parameters["order"].default.default == "desc"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_no_order_by_no_default_returns_none(self):
|
||||||
|
"""Returns None when order_by is absent and no default_field is set."""
|
||||||
|
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
|
||||||
|
dep = UserOrderCrud.order_params()
|
||||||
|
result = await dep(order_by=None, order="asc")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_no_order_by_with_default_field_returns_asc_expression(self):
|
||||||
|
"""Returns default_field.asc() when order_by absent and order=asc."""
|
||||||
|
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
|
||||||
|
dep = UserOrderCrud.order_params(default_field=User.username)
|
||||||
|
result = await dep(order_by=None, order="asc")
|
||||||
|
assert isinstance(result, UnaryExpression)
|
||||||
|
assert "ASC" in str(result)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_no_order_by_with_default_field_returns_desc_expression(self):
|
||||||
|
"""Returns default_field.desc() when order_by absent and order=desc."""
|
||||||
|
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
|
||||||
|
dep = UserOrderCrud.order_params(default_field=User.username)
|
||||||
|
result = await dep(order_by=None, order="desc")
|
||||||
|
assert isinstance(result, UnaryExpression)
|
||||||
|
assert "DESC" in str(result)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_valid_order_by_asc(self):
|
||||||
|
"""Returns field.asc() for a valid order_by with order=asc."""
|
||||||
|
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
|
||||||
|
dep = UserOrderCrud.order_params()
|
||||||
|
result = await dep(order_by="username", order="asc")
|
||||||
|
assert isinstance(result, UnaryExpression)
|
||||||
|
assert "ASC" in str(result)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_valid_order_by_desc(self):
|
||||||
|
"""Returns field.desc() for a valid order_by with order=desc."""
|
||||||
|
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
|
||||||
|
dep = UserOrderCrud.order_params()
|
||||||
|
result = await dep(order_by="username", order="desc")
|
||||||
|
assert isinstance(result, UnaryExpression)
|
||||||
|
assert "DESC" in str(result)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_invalid_order_by_raises_invalid_order_field_error(self):
|
||||||
|
"""Raises InvalidOrderFieldError for an unknown order_by value."""
|
||||||
|
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
|
||||||
|
dep = UserOrderCrud.order_params()
|
||||||
|
with pytest.raises(InvalidOrderFieldError) as exc_info:
|
||||||
|
await dep(order_by="nonexistent", order="asc")
|
||||||
|
assert exc_info.value.field == "nonexistent"
|
||||||
|
assert "username" in exc_info.value.valid_fields
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_multiple_fields_all_resolve(self):
|
||||||
|
"""All configured fields resolve correctly via order_by."""
|
||||||
|
UserOrderCrud = CrudFactory(User, order_fields=[User.username, User.email])
|
||||||
|
dep = UserOrderCrud.order_params()
|
||||||
|
result_username = await dep(order_by="username", order="asc")
|
||||||
|
result_email = await dep(order_by="email", order="desc")
|
||||||
|
assert isinstance(result_username, ColumnElement)
|
||||||
|
assert isinstance(result_email, ColumnElement)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_order_params_integrates_with_get_multi(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""order_params output is accepted by get_multi(order_by=...)."""
|
||||||
|
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="charlie", email="c@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
dep = UserOrderCrud.order_params()
|
||||||
|
order_by = await dep(order_by="username", order="asc")
|
||||||
|
results = await UserOrderCrud.get_multi(db_session, order_by=order_by)
|
||||||
|
|
||||||
|
assert results[0].username == "alice"
|
||||||
|
assert results[1].username == "charlie"
|
||||||
|
|||||||
102
tests/test_db.py
102
tests/test_db.py
@@ -1,5 +1,8 @@
|
|||||||
"""Tests for fastapi_toolsets.db module."""
|
"""Tests for fastapi_toolsets.db module."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
@@ -9,6 +12,7 @@ from fastapi_toolsets.db import (
|
|||||||
create_db_dependency,
|
create_db_dependency,
|
||||||
get_transaction,
|
get_transaction,
|
||||||
lock_tables,
|
lock_tables,
|
||||||
|
wait_for_row_change,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User
|
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User
|
||||||
@@ -241,3 +245,101 @@ class TestLockTables:
|
|||||||
|
|
||||||
result = await RoleCrud.first(db_session, [Role.name == "lock_rollback_role"])
|
result = await RoleCrud.first(db_session, [Role.name == "lock_rollback_role"])
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestWaitForRowChange:
|
||||||
|
"""Tests for wait_for_row_change polling function."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_detects_update(self, db_session: AsyncSession, engine):
|
||||||
|
"""Returns updated instance when a column value changes."""
|
||||||
|
role = Role(name="watch_role")
|
||||||
|
db_session.add(role)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
async def update_later():
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
async with factory() as other:
|
||||||
|
r = await other.get(Role, role.id)
|
||||||
|
assert r is not None
|
||||||
|
r.name = "updated_role"
|
||||||
|
await other.commit()
|
||||||
|
|
||||||
|
update_task = asyncio.create_task(update_later())
|
||||||
|
result = await wait_for_row_change(db_session, Role, role.id, interval=0.05)
|
||||||
|
await update_task
|
||||||
|
|
||||||
|
assert result.name == "updated_role"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_watches_specific_columns(self, db_session: AsyncSession, engine):
|
||||||
|
"""Only triggers on changes to specified columns."""
|
||||||
|
user = User(username="testuser", email="test@example.com")
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
async def update_later():
|
||||||
|
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
# First: change email (not watched) — should not trigger
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
async with factory() as other:
|
||||||
|
u = await other.get(User, user.id)
|
||||||
|
assert u is not None
|
||||||
|
u.email = "new@example.com"
|
||||||
|
await other.commit()
|
||||||
|
# Second: change username (watched) — should trigger
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
async with factory() as other:
|
||||||
|
u = await other.get(User, user.id)
|
||||||
|
assert u is not None
|
||||||
|
u.username = "newuser"
|
||||||
|
await other.commit()
|
||||||
|
|
||||||
|
update_task = asyncio.create_task(update_later())
|
||||||
|
result = await wait_for_row_change(
|
||||||
|
db_session, User, user.id, columns=["username"], interval=0.05
|
||||||
|
)
|
||||||
|
await update_task
|
||||||
|
|
||||||
|
assert result.username == "newuser"
|
||||||
|
assert result.email == "new@example.com"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_nonexistent_row_raises(self, db_session: AsyncSession):
|
||||||
|
"""Raises LookupError when the row does not exist."""
|
||||||
|
fake_id = uuid.uuid4()
|
||||||
|
with pytest.raises(LookupError, match="not found"):
|
||||||
|
await wait_for_row_change(db_session, Role, fake_id, interval=0.05)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_timeout_raises(self, db_session: AsyncSession):
|
||||||
|
"""Raises TimeoutError when no change is detected within timeout."""
|
||||||
|
role = Role(name="timeout_role")
|
||||||
|
db_session.add(role)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
with pytest.raises(TimeoutError):
|
||||||
|
await wait_for_row_change(
|
||||||
|
db_session, Role, role.id, interval=0.05, timeout=0.2
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_deleted_row_raises(self, db_session: AsyncSession, engine):
|
||||||
|
"""Raises LookupError when the row is deleted during polling."""
|
||||||
|
role = Role(name="delete_role")
|
||||||
|
db_session.add(role)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
async def delete_later():
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
async with factory() as other:
|
||||||
|
r = await other.get(Role, role.id)
|
||||||
|
await other.delete(r)
|
||||||
|
await other.commit()
|
||||||
|
|
||||||
|
delete_task = asyncio.create_task(delete_later())
|
||||||
|
with pytest.raises(LookupError):
|
||||||
|
await wait_for_row_change(db_session, Role, role.id, interval=0.05)
|
||||||
|
await delete_task
|
||||||
|
|||||||
186
tests/test_dependencies.py
Normal file
186
tests/test_dependencies.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
"""Tests for fastapi_toolsets.dependencies module."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import uuid
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.params import Depends
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from fastapi_toolsets.dependencies import BodyDependency, PathDependency
|
||||||
|
|
||||||
|
from .conftest import Role, RoleCreate, RoleCrud, User
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""Mock session dependency for testing."""
|
||||||
|
yield None
|
||||||
|
|
||||||
|
|
||||||
|
class TestPathDependency:
|
||||||
|
"""Tests for PathDependency factory."""
|
||||||
|
|
||||||
|
def test_returns_depends_instance(self):
|
||||||
|
"""PathDependency returns a Depends instance."""
|
||||||
|
dep = PathDependency(Role, Role.id, session_dep=mock_get_db)
|
||||||
|
assert isinstance(dep, Depends)
|
||||||
|
|
||||||
|
def test_signature_has_default_param_name(self):
|
||||||
|
"""PathDependency uses model_field as default param name."""
|
||||||
|
dep = cast(Any, PathDependency(Role, Role.id, session_dep=mock_get_db))
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
params = list(sig.parameters.keys())
|
||||||
|
|
||||||
|
assert "role_id" in params
|
||||||
|
assert "session" in params
|
||||||
|
|
||||||
|
def test_signature_has_correct_type_annotation(self):
|
||||||
|
"""PathDependency uses field's python type for annotation."""
|
||||||
|
dep = cast(Any, PathDependency(Role, Role.id, session_dep=mock_get_db))
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
assert sig.parameters["role_id"].annotation == uuid.UUID
|
||||||
|
assert sig.parameters["session"].annotation == AsyncSession
|
||||||
|
|
||||||
|
def test_signature_session_has_depends_default(self):
|
||||||
|
"""PathDependency session param has Depends as default."""
|
||||||
|
dep = cast(Any, PathDependency(Role, Role.id, session_dep=mock_get_db))
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
assert isinstance(sig.parameters["session"].default, Depends)
|
||||||
|
|
||||||
|
def test_custom_param_name_in_signature(self):
|
||||||
|
"""PathDependency uses custom param_name in signature."""
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
PathDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, param_name="role_uuid"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
params = list(sig.parameters.keys())
|
||||||
|
|
||||||
|
assert "role_uuid" in params
|
||||||
|
assert "id" not in params
|
||||||
|
|
||||||
|
def test_string_field_type(self):
|
||||||
|
"""PathDependency handles string field types."""
|
||||||
|
dep = cast(Any, PathDependency(User, User.username, session_dep=mock_get_db))
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
assert sig.parameters["user_username"].annotation is str
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_dependency_fetches_object(self, db_session):
|
||||||
|
"""PathDependency inner function fetches object from database."""
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="test_role"))
|
||||||
|
|
||||||
|
dep = cast(Any, PathDependency(Role, Role.id, session_dep=mock_get_db))
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
result = await func(session=db_session, role_id=role.id)
|
||||||
|
|
||||||
|
assert result.id == role.id
|
||||||
|
assert result.name == "test_role"
|
||||||
|
|
||||||
|
|
||||||
|
class TestBodyDependency:
|
||||||
|
"""Tests for BodyDependency factory."""
|
||||||
|
|
||||||
|
def test_returns_depends_instance(self):
|
||||||
|
"""BodyDependency returns a Depends instance."""
|
||||||
|
dep = BodyDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||||
|
)
|
||||||
|
assert isinstance(dep, Depends)
|
||||||
|
|
||||||
|
def test_signature_has_body_field_as_param(self):
|
||||||
|
"""BodyDependency uses body_field as param name."""
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
params = list(sig.parameters.keys())
|
||||||
|
|
||||||
|
assert "role_id" in params
|
||||||
|
assert "session" in params
|
||||||
|
|
||||||
|
def test_signature_has_correct_type_annotation(self):
|
||||||
|
"""BodyDependency uses field's python type for annotation."""
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
assert sig.parameters["role_id"].annotation == uuid.UUID
|
||||||
|
assert sig.parameters["session"].annotation == AsyncSession
|
||||||
|
|
||||||
|
def test_signature_session_has_depends_default(self):
|
||||||
|
"""BodyDependency session param has Depends as default."""
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
assert isinstance(sig.parameters["session"].default, Depends)
|
||||||
|
|
||||||
|
def test_different_body_field_name(self):
|
||||||
|
"""BodyDependency can use any body_field name."""
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
User, User.id, session_dep=mock_get_db, body_field="user_uuid"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
params = list(sig.parameters.keys())
|
||||||
|
|
||||||
|
assert "user_uuid" in params
|
||||||
|
assert "id" not in params
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_dependency_fetches_object(self, db_session):
|
||||||
|
"""BodyDependency inner function fetches object from database."""
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="body_test_role"))
|
||||||
|
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
result = await func(session=db_session, role_id=role.id)
|
||||||
|
|
||||||
|
assert result.id == role.id
|
||||||
|
assert result.name == "body_test_role"
|
||||||
395
tests/test_example_pagination_search.py
Normal file
395
tests/test_example_pagination_search.py
Normal file
@@ -0,0 +1,395 @@
|
|||||||
|
"""Live test for the docs/examples/pagination-search.md example.
|
||||||
|
|
||||||
|
Spins up the exact FastAPI app described in the example (sourced from
|
||||||
|
docs_src/examples/pagination_search/) and exercises it through a real HTTP
|
||||||
|
client against a real PostgreSQL database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from docs_src.examples.pagination_search.db import get_db
|
||||||
|
from docs_src.examples.pagination_search.models import Article, Base, Category
|
||||||
|
from docs_src.examples.pagination_search.routes import router
|
||||||
|
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||||
|
|
||||||
|
from .conftest import DATABASE_URL
|
||||||
|
|
||||||
|
|
||||||
|
def build_app(session: AsyncSession) -> FastAPI:
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
async def override_get_db():
|
||||||
|
yield session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_db] = override_get_db
|
||||||
|
app.include_router(router)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
async def ex_db_session():
|
||||||
|
"""Isolated session for the example models (separate tables from conftest)."""
|
||||||
|
engine = create_async_engine(DATABASE_URL, echo=False)
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
session = session_factory()
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.drop_all)
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client(ex_db_session: AsyncSession):
|
||||||
|
app = build_app(ex_db_session)
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=app), base_url="http://test"
|
||||||
|
) as ac:
|
||||||
|
yield ac
|
||||||
|
|
||||||
|
|
||||||
|
async def seed(session: AsyncSession):
|
||||||
|
"""Insert representative fixture data."""
|
||||||
|
python = Category(name="python")
|
||||||
|
backend = Category(name="backend")
|
||||||
|
session.add_all([python, backend])
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
now = datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc)
|
||||||
|
session.add_all(
|
||||||
|
[
|
||||||
|
Article(
|
||||||
|
title="FastAPI tips",
|
||||||
|
body="Ten useful tips for FastAPI.",
|
||||||
|
status="published",
|
||||||
|
published=True,
|
||||||
|
category_id=python.id,
|
||||||
|
created_at=now,
|
||||||
|
),
|
||||||
|
Article(
|
||||||
|
title="SQLAlchemy async",
|
||||||
|
body="How to use async SQLAlchemy.",
|
||||||
|
status="published",
|
||||||
|
published=True,
|
||||||
|
category_id=backend.id,
|
||||||
|
created_at=now + datetime.timedelta(seconds=1),
|
||||||
|
),
|
||||||
|
Article(
|
||||||
|
title="Draft notes",
|
||||||
|
body="Work in progress.",
|
||||||
|
status="draft",
|
||||||
|
published=False,
|
||||||
|
category_id=None,
|
||||||
|
created_at=now + datetime.timedelta(seconds=2),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAppSessionDep:
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_db_yields_async_session(self):
|
||||||
|
"""get_db yields a real AsyncSession when called directly."""
|
||||||
|
from docs_src.examples.pagination_search.db import get_db
|
||||||
|
|
||||||
|
gen = get_db()
|
||||||
|
session = await gen.__anext__()
|
||||||
|
assert isinstance(session, AsyncSession)
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
|
class TestOffsetPagination:
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_returns_all_articles(self, client: AsyncClient, ex_db_session):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["pagination"]["total_count"] == 3
|
||||||
|
assert len(body["data"]) == 3
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_pagination_page_size(self, client: AsyncClient, ex_db_session):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset?items_per_page=2&page=1")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert len(body["data"]) == 2
|
||||||
|
assert body["pagination"]["total_count"] == 3
|
||||||
|
assert body["pagination"]["has_more"] is True
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_full_text_search(self, client: AsyncClient, ex_db_session):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset?search=fastapi")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["pagination"]["total_count"] == 1
|
||||||
|
assert body["data"][0]["title"] == "FastAPI tips"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_traverses_relationship(
|
||||||
|
self, client: AsyncClient, ex_db_session
|
||||||
|
):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
# "python" matches Category.name, not Article.title or body
|
||||||
|
resp = await client.get("/articles/offset?search=python")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["pagination"]["total_count"] == 1
|
||||||
|
assert body["data"][0]["title"] == "FastAPI tips"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_facet_filter_scalar(self, client: AsyncClient, ex_db_session):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset?status=published")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["pagination"]["total_count"] == 2
|
||||||
|
assert all(a["status"] == "published" for a in body["data"])
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_facet_filter_multi_value(self, client: AsyncClient, ex_db_session):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset?status=published&status=draft")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["pagination"]["total_count"] == 3
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_filter_attributes_in_response(
|
||||||
|
self, client: AsyncClient, ex_db_session
|
||||||
|
):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
fa = body["filter_attributes"]
|
||||||
|
assert set(fa["status"]) == {"draft", "published"}
|
||||||
|
# "name" is unique across all facet fields — no prefix needed
|
||||||
|
assert set(fa["name"]) == {"backend", "python"}
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_filter_attributes_scoped_to_filter(
|
||||||
|
self, client: AsyncClient, ex_db_session
|
||||||
|
):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset?status=published")
|
||||||
|
|
||||||
|
body = resp.json()
|
||||||
|
# draft is filtered out → should not appear in filter_attributes
|
||||||
|
assert "draft" not in body["filter_attributes"]["status"]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_and_filter_combined(self, client: AsyncClient, ex_db_session):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset?search=async&status=published")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["pagination"]["total_count"] == 1
|
||||||
|
assert body["data"][0]["title"] == "SQLAlchemy async"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCursorPagination:
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_first_page(self, client: AsyncClient, ex_db_session):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/cursor?items_per_page=2")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert len(body["data"]) == 2
|
||||||
|
assert body["pagination"]["has_more"] is True
|
||||||
|
assert body["pagination"]["next_cursor"] is not None
|
||||||
|
assert body["pagination"]["prev_cursor"] is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_second_page(self, client: AsyncClient, ex_db_session):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
first = await client.get("/articles/cursor?items_per_page=2")
|
||||||
|
next_cursor = first.json()["pagination"]["next_cursor"]
|
||||||
|
|
||||||
|
resp = await client.get(
|
||||||
|
f"/articles/cursor?items_per_page=2&cursor={next_cursor}"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert len(body["data"]) == 1
|
||||||
|
assert body["pagination"]["has_more"] is False
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_facet_filter(self, client: AsyncClient, ex_db_session):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/cursor?status=draft")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert len(body["data"]) == 1
|
||||||
|
assert body["data"][0]["status"] == "draft"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_full_text_search(self, client: AsyncClient, ex_db_session):
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/cursor?search=sqlalchemy")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert len(body["data"]) == 1
|
||||||
|
assert body["data"][0]["title"] == "SQLAlchemy async"
|
||||||
|
|
||||||
|
|
||||||
|
class TestOffsetSorting:
|
||||||
|
"""Tests for order_by / order query parameters on the offset endpoint."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_default_order_uses_created_at_asc(
|
||||||
|
self, client: AsyncClient, ex_db_session
|
||||||
|
):
|
||||||
|
"""No order_by → default field (created_at) ASC."""
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
titles = [a["title"] for a in resp.json()["data"]]
|
||||||
|
assert titles == ["FastAPI tips", "SQLAlchemy async", "Draft notes"]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_order_by_title_asc(self, client: AsyncClient, ex_db_session):
|
||||||
|
"""order_by=title&order=asc returns alphabetical order."""
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset?order_by=title&order=asc")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
titles = [a["title"] for a in resp.json()["data"]]
|
||||||
|
assert titles == ["Draft notes", "FastAPI tips", "SQLAlchemy async"]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_order_by_title_desc(self, client: AsyncClient, ex_db_session):
|
||||||
|
"""order_by=title&order=desc returns reverse alphabetical order."""
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset?order_by=title&order=desc")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
titles = [a["title"] for a in resp.json()["data"]]
|
||||||
|
assert titles == ["SQLAlchemy async", "FastAPI tips", "Draft notes"]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_order_by_created_at_desc(self, client: AsyncClient, ex_db_session):
|
||||||
|
"""order_by=created_at&order=desc returns newest-first."""
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/offset?order_by=created_at&order=desc")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
titles = [a["title"] for a in resp.json()["data"]]
|
||||||
|
assert titles == ["Draft notes", "SQLAlchemy async", "FastAPI tips"]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_invalid_order_by_returns_422(
|
||||||
|
self, client: AsyncClient, ex_db_session
|
||||||
|
):
|
||||||
|
"""Unknown order_by field returns 422 with SORT-422 error code."""
|
||||||
|
resp = await client.get("/articles/offset?order_by=nonexistent_field")
|
||||||
|
|
||||||
|
assert resp.status_code == 422
|
||||||
|
body = resp.json()
|
||||||
|
assert body["error_code"] == "SORT-422"
|
||||||
|
assert body["status"] == "FAIL"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCursorSorting:
|
||||||
|
"""Tests for order_by / order query parameters on the cursor endpoint.
|
||||||
|
|
||||||
|
In cursor_paginate the cursor_column is always the primary sort; order_by
|
||||||
|
acts as a secondary tiebreaker. With the seeded articles (all having unique
|
||||||
|
created_at values) the overall ordering is always created_at ASC regardless
|
||||||
|
of the order_by value — only the valid/invalid field check and the response
|
||||||
|
shape are meaningful here.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_default_order_uses_created_at_asc(
|
||||||
|
self, client: AsyncClient, ex_db_session
|
||||||
|
):
|
||||||
|
"""No order_by → default field (created_at) ASC."""
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/cursor")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
titles = [a["title"] for a in resp.json()["data"]]
|
||||||
|
assert titles == ["FastAPI tips", "SQLAlchemy async", "Draft notes"]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_order_by_title_asc_accepted(
|
||||||
|
self, client: AsyncClient, ex_db_session
|
||||||
|
):
|
||||||
|
"""order_by=title is a valid field — request succeeds and returns all articles."""
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/cursor?order_by=title&order=asc")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert len(resp.json()["data"]) == 3
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_order_by_title_desc_accepted(
|
||||||
|
self, client: AsyncClient, ex_db_session
|
||||||
|
):
|
||||||
|
"""order_by=title&order=desc is valid — request succeeds and returns all articles."""
|
||||||
|
await seed(ex_db_session)
|
||||||
|
|
||||||
|
resp = await client.get("/articles/cursor?order_by=title&order=desc")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert len(resp.json()["data"]) == 3
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_invalid_order_by_returns_422(
|
||||||
|
self, client: AsyncClient, ex_db_session
|
||||||
|
):
|
||||||
|
"""Unknown order_by field returns 422 with SORT-422 error code."""
|
||||||
|
resp = await client.get("/articles/cursor?order_by=nonexistent_field")
|
||||||
|
|
||||||
|
assert resp.status_code == 422
|
||||||
|
body = resp.json()
|
||||||
|
assert body["error_code"] == "SORT-422"
|
||||||
|
assert body["status"] == "FAIL"
|
||||||
@@ -8,6 +8,7 @@ from fastapi_toolsets.exceptions import (
|
|||||||
ApiException,
|
ApiException,
|
||||||
ConflictError,
|
ConflictError,
|
||||||
ForbiddenError,
|
ForbiddenError,
|
||||||
|
InvalidOrderFieldError,
|
||||||
NotFoundError,
|
NotFoundError,
|
||||||
UnauthorizedError,
|
UnauthorizedError,
|
||||||
generate_error_responses,
|
generate_error_responses,
|
||||||
@@ -108,6 +109,24 @@ class TestGenerateErrorResponses:
|
|||||||
assert example["status"] == "FAIL"
|
assert example["status"] == "FAIL"
|
||||||
assert example["error_code"] == "RES-404"
|
assert example["error_code"] == "RES-404"
|
||||||
assert example["message"] == "Not Found"
|
assert example["message"] == "Not Found"
|
||||||
|
assert example["data"] is None
|
||||||
|
|
||||||
|
def test_response_example_with_data(self):
|
||||||
|
"""Generated response includes data when set on ApiError."""
|
||||||
|
|
||||||
|
class ErrorWithData(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400,
|
||||||
|
msg="Bad Request",
|
||||||
|
desc="Invalid input.",
|
||||||
|
err_code="BAD-400",
|
||||||
|
data={"details": "some context"},
|
||||||
|
)
|
||||||
|
|
||||||
|
responses = generate_error_responses(ErrorWithData)
|
||||||
|
example = responses[400]["content"]["application/json"]["example"]
|
||||||
|
|
||||||
|
assert example["data"] == {"details": "some context"}
|
||||||
|
|
||||||
|
|
||||||
class TestInitExceptionsHandlers:
|
class TestInitExceptionsHandlers:
|
||||||
@@ -137,6 +156,59 @@ class TestInitExceptionsHandlers:
|
|||||||
assert data["error_code"] == "RES-404"
|
assert data["error_code"] == "RES-404"
|
||||||
assert data["message"] == "Not Found"
|
assert data["message"] == "Not Found"
|
||||||
|
|
||||||
|
def test_handles_api_exception_without_data(self):
|
||||||
|
"""ApiException without data returns null data field."""
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
@app.get("/error")
|
||||||
|
async def raise_error():
|
||||||
|
raise NotFoundError()
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/error")
|
||||||
|
|
||||||
|
assert response.status_code == 404
|
||||||
|
assert response.json()["data"] is None
|
||||||
|
|
||||||
|
def test_handles_api_exception_with_data(self):
|
||||||
|
"""ApiException with data returns the data payload."""
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
class CustomValidationError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=422,
|
||||||
|
msg="Validation Error",
|
||||||
|
desc="1 validation error(s) detected",
|
||||||
|
err_code="CUSTOM-422",
|
||||||
|
data={
|
||||||
|
"errors": [
|
||||||
|
{
|
||||||
|
"field": "email",
|
||||||
|
"message": "invalid format",
|
||||||
|
"type": "value_error",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/error")
|
||||||
|
async def raise_error():
|
||||||
|
raise CustomValidationError()
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/error")
|
||||||
|
|
||||||
|
assert response.status_code == 422
|
||||||
|
data = response.json()
|
||||||
|
assert data["data"] == {
|
||||||
|
"errors": [
|
||||||
|
{"field": "email", "message": "invalid format", "type": "value_error"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
assert data["error_code"] == "CUSTOM-422"
|
||||||
|
|
||||||
def test_handles_validation_error(self):
|
def test_handles_validation_error(self):
|
||||||
"""Handles validation errors with structured response."""
|
"""Handles validation errors with structured response."""
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -263,3 +335,43 @@ class TestExceptionIntegration:
|
|||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {"id": 1}
|
assert response.json() == {"id": 1}
|
||||||
|
|
||||||
|
|
||||||
|
class TestInvalidOrderFieldError:
|
||||||
|
"""Tests for InvalidOrderFieldError exception."""
|
||||||
|
|
||||||
|
def test_api_error_attributes(self):
|
||||||
|
"""InvalidOrderFieldError has correct api_error metadata."""
|
||||||
|
assert InvalidOrderFieldError.api_error.code == 422
|
||||||
|
assert InvalidOrderFieldError.api_error.err_code == "SORT-422"
|
||||||
|
assert InvalidOrderFieldError.api_error.msg == "Invalid Order Field"
|
||||||
|
|
||||||
|
def test_stores_field_and_valid_fields(self):
|
||||||
|
"""InvalidOrderFieldError stores field and valid_fields on the instance."""
|
||||||
|
error = InvalidOrderFieldError("unknown", ["name", "created_at"])
|
||||||
|
assert error.field == "unknown"
|
||||||
|
assert error.valid_fields == ["name", "created_at"]
|
||||||
|
|
||||||
|
def test_message_contains_field_and_valid_fields(self):
|
||||||
|
"""Exception message mentions the bad field and valid options."""
|
||||||
|
error = InvalidOrderFieldError("bad_field", ["name", "email"])
|
||||||
|
assert "bad_field" in str(error)
|
||||||
|
assert "name" in str(error)
|
||||||
|
assert "email" in str(error)
|
||||||
|
|
||||||
|
def test_handled_as_422_by_exception_handler(self):
|
||||||
|
"""init_exceptions_handlers turns InvalidOrderFieldError into a 422 response."""
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
@app.get("/items")
|
||||||
|
async def list_items():
|
||||||
|
raise InvalidOrderFieldError("bad", ["name"])
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/items")
|
||||||
|
|
||||||
|
assert response.status_code == 422
|
||||||
|
data = response.json()
|
||||||
|
assert data["error_code"] == "SORT-422"
|
||||||
|
assert data["status"] == "FAIL"
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""Tests for fastapi_toolsets.fixtures module."""
|
"""Tests for fastapi_toolsets.fixtures module."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
@@ -57,20 +59,22 @@ class TestFixtureRegistry:
|
|||||||
def test_register_with_decorator(self):
|
def test_register_with_decorator(self):
|
||||||
"""Register fixture with decorator."""
|
"""Register fixture with decorator."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=role_id, name="admin")]
|
||||||
|
|
||||||
assert "roles" in [f.name for f in registry.get_all()]
|
assert "roles" in [f.name for f in registry.get_all()]
|
||||||
|
|
||||||
def test_register_with_custom_name(self):
|
def test_register_with_custom_name(self):
|
||||||
"""Register fixture with custom name."""
|
"""Register fixture with custom name."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register(name="custom_roles")
|
@registry.register(name="custom_roles")
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=role_id, name="admin")]
|
||||||
|
|
||||||
fixture = registry.get("custom_roles")
|
fixture = registry.get("custom_roles")
|
||||||
assert fixture.name == "custom_roles"
|
assert fixture.name == "custom_roles"
|
||||||
@@ -78,14 +82,23 @@ class TestFixtureRegistry:
|
|||||||
def test_register_with_dependencies(self):
|
def test_register_with_dependencies(self):
|
||||||
"""Register fixture with dependencies."""
|
"""Register fixture with dependencies."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=role_id, name="admin")]
|
||||||
|
|
||||||
@registry.register(depends_on=["roles"])
|
@registry.register(depends_on=["roles"])
|
||||||
def users():
|
def users():
|
||||||
return [User(id=1, username="admin", email="admin@test.com", role_id=1)]
|
return [
|
||||||
|
User(
|
||||||
|
id=user_id,
|
||||||
|
username="admin",
|
||||||
|
email="admin@test.com",
|
||||||
|
role_id=role_id,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
fixture = registry.get("users")
|
fixture = registry.get("users")
|
||||||
assert fixture.depends_on == ["roles"]
|
assert fixture.depends_on == ["roles"]
|
||||||
@@ -93,10 +106,11 @@ class TestFixtureRegistry:
|
|||||||
def test_register_with_contexts(self):
|
def test_register_with_contexts(self):
|
||||||
"""Register fixture with contexts."""
|
"""Register fixture with contexts."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register(contexts=[Context.TESTING])
|
@registry.register(contexts=[Context.TESTING])
|
||||||
def test_data():
|
def test_data():
|
||||||
return [Role(id=100, name="test")]
|
return [Role(id=role_id, name="test")]
|
||||||
|
|
||||||
fixture = registry.get("test_data")
|
fixture = registry.get("test_data")
|
||||||
assert Context.TESTING.value in fixture.contexts
|
assert Context.TESTING.value in fixture.contexts
|
||||||
@@ -145,6 +159,178 @@ class TestFixtureRegistry:
|
|||||||
assert names == {"test_data"}
|
assert names == {"test_data"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestIncludeRegistry:
|
||||||
|
"""Tests for FixtureRegistry.include_registry method."""
|
||||||
|
|
||||||
|
def test_include_empty_registry(self):
|
||||||
|
"""Include an empty registry does nothing."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
other_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@main_registry.register
|
||||||
|
def roles():
|
||||||
|
return []
|
||||||
|
|
||||||
|
main_registry.include_registry(other_registry)
|
||||||
|
|
||||||
|
assert len(main_registry.get_all()) == 1
|
||||||
|
|
||||||
|
def test_include_registry_adds_fixtures(self):
|
||||||
|
"""Include registry adds all fixtures from the other registry."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
other_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@main_registry.register
|
||||||
|
def roles():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@other_registry.register
|
||||||
|
def users():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@other_registry.register
|
||||||
|
def posts():
|
||||||
|
return []
|
||||||
|
|
||||||
|
main_registry.include_registry(other_registry)
|
||||||
|
|
||||||
|
names = {f.name for f in main_registry.get_all()}
|
||||||
|
assert names == {"roles", "users", "posts"}
|
||||||
|
|
||||||
|
def test_include_registry_preserves_dependencies(self):
|
||||||
|
"""Include registry preserves fixture dependencies."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
other_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@main_registry.register
|
||||||
|
def roles():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@other_registry.register(depends_on=["roles"])
|
||||||
|
def users():
|
||||||
|
return []
|
||||||
|
|
||||||
|
main_registry.include_registry(other_registry)
|
||||||
|
|
||||||
|
fixture = main_registry.get("users")
|
||||||
|
assert fixture.depends_on == ["roles"]
|
||||||
|
|
||||||
|
def test_include_registry_preserves_contexts(self):
|
||||||
|
"""Include registry preserves fixture contexts."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
other_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@other_registry.register(contexts=[Context.TESTING, Context.DEVELOPMENT])
|
||||||
|
def test_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
main_registry.include_registry(other_registry)
|
||||||
|
|
||||||
|
fixture = main_registry.get("test_data")
|
||||||
|
assert Context.TESTING.value in fixture.contexts
|
||||||
|
assert Context.DEVELOPMENT.value in fixture.contexts
|
||||||
|
|
||||||
|
def test_include_registry_raises_on_duplicate(self):
|
||||||
|
"""Include registry raises ValueError on duplicate fixture names."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
other_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@main_registry.register(name="roles")
|
||||||
|
def roles_main():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@other_registry.register(name="roles")
|
||||||
|
def roles_other():
|
||||||
|
return []
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="already exists"):
|
||||||
|
main_registry.include_registry(other_registry)
|
||||||
|
|
||||||
|
def test_include_multiple_registries(self):
|
||||||
|
"""Include multiple registries sequentially."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
dev_registry = FixtureRegistry()
|
||||||
|
test_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@main_registry.register
|
||||||
|
def base():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@dev_registry.register
|
||||||
|
def dev_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@test_registry.register
|
||||||
|
def test_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
main_registry.include_registry(dev_registry)
|
||||||
|
main_registry.include_registry(test_registry)
|
||||||
|
|
||||||
|
names = {f.name for f in main_registry.get_all()}
|
||||||
|
assert names == {"base", "dev_data", "test_data"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestDefaultContexts:
|
||||||
|
"""Tests for FixtureRegistry default contexts."""
|
||||||
|
|
||||||
|
def test_default_contexts_applied_to_fixtures(self):
|
||||||
|
"""Default contexts are applied when no contexts specified."""
|
||||||
|
registry = FixtureRegistry(contexts=[Context.TESTING])
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def test_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
fixture = registry.get("test_data")
|
||||||
|
assert fixture.contexts == [Context.TESTING.value]
|
||||||
|
|
||||||
|
def test_explicit_contexts_override_default(self):
|
||||||
|
"""Explicit contexts override default contexts."""
|
||||||
|
registry = FixtureRegistry(contexts=[Context.TESTING])
|
||||||
|
|
||||||
|
@registry.register(contexts=[Context.PRODUCTION])
|
||||||
|
def prod_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
fixture = registry.get("prod_data")
|
||||||
|
assert fixture.contexts == [Context.PRODUCTION.value]
|
||||||
|
|
||||||
|
def test_no_default_contexts_uses_base(self):
|
||||||
|
"""Without default contexts, BASE is used."""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
fixture = registry.get("data")
|
||||||
|
assert fixture.contexts == [Context.BASE.value]
|
||||||
|
|
||||||
|
def test_multiple_default_contexts(self):
|
||||||
|
"""Multiple default contexts are applied."""
|
||||||
|
registry = FixtureRegistry(contexts=[Context.DEVELOPMENT, Context.TESTING])
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def dev_test_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
fixture = registry.get("dev_test_data")
|
||||||
|
assert Context.DEVELOPMENT.value in fixture.contexts
|
||||||
|
assert Context.TESTING.value in fixture.contexts
|
||||||
|
|
||||||
|
def test_default_contexts_with_string_values(self):
|
||||||
|
"""Default contexts work with string values."""
|
||||||
|
registry = FixtureRegistry(contexts=["custom_context"])
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def custom_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
fixture = registry.get("custom_data")
|
||||||
|
assert fixture.contexts == ["custom_context"]
|
||||||
|
|
||||||
|
|
||||||
class TestDependencyResolution:
|
class TestDependencyResolution:
|
||||||
"""Tests for fixture dependency resolution."""
|
"""Tests for fixture dependency resolution."""
|
||||||
|
|
||||||
@@ -244,12 +430,14 @@ class TestLoadFixtures:
|
|||||||
async def test_load_single_fixture(self, db_session: AsyncSession):
|
async def test_load_single_fixture(self, db_session: AsyncSession):
|
||||||
"""Load a single fixture."""
|
"""Load a single fixture."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id_1 = uuid.uuid4()
|
||||||
|
role_id_2 = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def roles():
|
def roles():
|
||||||
return [
|
return [
|
||||||
Role(id=1, name="admin"),
|
Role(id=role_id_1, name="admin"),
|
||||||
Role(id=2, name="user"),
|
Role(id=role_id_2, name="user"),
|
||||||
]
|
]
|
||||||
|
|
||||||
result = await load_fixtures(db_session, registry, "roles")
|
result = await load_fixtures(db_session, registry, "roles")
|
||||||
@@ -266,14 +454,23 @@ class TestLoadFixtures:
|
|||||||
async def test_load_with_dependencies(self, db_session: AsyncSession):
|
async def test_load_with_dependencies(self, db_session: AsyncSession):
|
||||||
"""Load fixtures with dependencies."""
|
"""Load fixtures with dependencies."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=role_id, name="admin")]
|
||||||
|
|
||||||
@registry.register(depends_on=["roles"])
|
@registry.register(depends_on=["roles"])
|
||||||
def users():
|
def users():
|
||||||
return [User(id=1, username="admin", email="admin@test.com", role_id=1)]
|
return [
|
||||||
|
User(
|
||||||
|
id=user_id,
|
||||||
|
username="admin",
|
||||||
|
email="admin@test.com",
|
||||||
|
role_id=role_id,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
result = await load_fixtures(db_session, registry, "users")
|
result = await load_fixtures(db_session, registry, "users")
|
||||||
|
|
||||||
@@ -289,10 +486,11 @@ class TestLoadFixtures:
|
|||||||
async def test_load_with_merge_strategy(self, db_session: AsyncSession):
|
async def test_load_with_merge_strategy(self, db_session: AsyncSession):
|
||||||
"""Load fixtures with MERGE strategy updates existing."""
|
"""Load fixtures with MERGE strategy updates existing."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=role_id, name="admin")]
|
||||||
|
|
||||||
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
|
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
|
||||||
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
|
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
|
||||||
@@ -306,10 +504,11 @@ class TestLoadFixtures:
|
|||||||
async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession):
|
async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession):
|
||||||
"""Load fixtures with SKIP_EXISTING strategy."""
|
"""Load fixtures with SKIP_EXISTING strategy."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="original")]
|
return [Role(id=role_id, name="original")]
|
||||||
|
|
||||||
await load_fixtures(
|
await load_fixtures(
|
||||||
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
|
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
|
||||||
@@ -317,7 +516,7 @@ class TestLoadFixtures:
|
|||||||
|
|
||||||
@registry.register(name="roles_updated")
|
@registry.register(name="roles_updated")
|
||||||
def roles_v2():
|
def roles_v2():
|
||||||
return [Role(id=1, name="updated")]
|
return [Role(id=role_id, name="updated")]
|
||||||
|
|
||||||
registry._fixtures["roles"] = registry._fixtures.pop("roles_updated")
|
registry._fixtures["roles"] = registry._fixtures.pop("roles_updated")
|
||||||
|
|
||||||
@@ -327,7 +526,7 @@ class TestLoadFixtures:
|
|||||||
|
|
||||||
from .conftest import RoleCrud
|
from .conftest import RoleCrud
|
||||||
|
|
||||||
role = await RoleCrud.first(db_session, [Role.id == 1])
|
role = await RoleCrud.first(db_session, [Role.id == role_id])
|
||||||
assert role is not None
|
assert role is not None
|
||||||
assert role.name == "original"
|
assert role.name == "original"
|
||||||
|
|
||||||
@@ -335,12 +534,14 @@ class TestLoadFixtures:
|
|||||||
async def test_load_with_insert_strategy(self, db_session: AsyncSession):
|
async def test_load_with_insert_strategy(self, db_session: AsyncSession):
|
||||||
"""Load fixtures with INSERT strategy."""
|
"""Load fixtures with INSERT strategy."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id_1 = uuid.uuid4()
|
||||||
|
role_id_2 = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def roles():
|
def roles():
|
||||||
return [
|
return [
|
||||||
Role(id=1, name="admin"),
|
Role(id=role_id_1, name="admin"),
|
||||||
Role(id=2, name="user"),
|
Role(id=role_id_2, name="user"),
|
||||||
]
|
]
|
||||||
|
|
||||||
result = await load_fixtures(
|
result = await load_fixtures(
|
||||||
@@ -375,14 +576,16 @@ class TestLoadFixtures:
|
|||||||
):
|
):
|
||||||
"""Load multiple independent fixtures."""
|
"""Load multiple independent fixtures."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id_1 = uuid.uuid4()
|
||||||
|
role_id_2 = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=role_id_1, name="admin")]
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def other_roles():
|
def other_roles():
|
||||||
return [Role(id=2, name="user")]
|
return [Role(id=role_id_2, name="user")]
|
||||||
|
|
||||||
result = await load_fixtures(db_session, registry, "roles", "other_roles")
|
result = await load_fixtures(db_session, registry, "roles", "other_roles")
|
||||||
|
|
||||||
@@ -402,14 +605,16 @@ class TestLoadFixturesByContext:
|
|||||||
async def test_load_by_single_context(self, db_session: AsyncSession):
|
async def test_load_by_single_context(self, db_session: AsyncSession):
|
||||||
"""Load fixtures by single context."""
|
"""Load fixtures by single context."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
base_role_id = uuid.uuid4()
|
||||||
|
test_role_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register(contexts=[Context.BASE])
|
@registry.register(contexts=[Context.BASE])
|
||||||
def base_roles():
|
def base_roles():
|
||||||
return [Role(id=1, name="base_role")]
|
return [Role(id=base_role_id, name="base_role")]
|
||||||
|
|
||||||
@registry.register(contexts=[Context.TESTING])
|
@registry.register(contexts=[Context.TESTING])
|
||||||
def test_roles():
|
def test_roles():
|
||||||
return [Role(id=100, name="test_role")]
|
return [Role(id=test_role_id, name="test_role")]
|
||||||
|
|
||||||
await load_fixtures_by_context(db_session, registry, Context.BASE)
|
await load_fixtures_by_context(db_session, registry, Context.BASE)
|
||||||
|
|
||||||
@@ -418,7 +623,7 @@ class TestLoadFixturesByContext:
|
|||||||
count = await RoleCrud.count(db_session)
|
count = await RoleCrud.count(db_session)
|
||||||
assert count == 1
|
assert count == 1
|
||||||
|
|
||||||
role = await RoleCrud.first(db_session, [Role.id == 1])
|
role = await RoleCrud.first(db_session, [Role.id == base_role_id])
|
||||||
assert role is not None
|
assert role is not None
|
||||||
assert role.name == "base_role"
|
assert role.name == "base_role"
|
||||||
|
|
||||||
@@ -426,14 +631,16 @@ class TestLoadFixturesByContext:
|
|||||||
async def test_load_by_multiple_contexts(self, db_session: AsyncSession):
|
async def test_load_by_multiple_contexts(self, db_session: AsyncSession):
|
||||||
"""Load fixtures by multiple contexts."""
|
"""Load fixtures by multiple contexts."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
base_role_id = uuid.uuid4()
|
||||||
|
test_role_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register(contexts=[Context.BASE])
|
@registry.register(contexts=[Context.BASE])
|
||||||
def base_roles():
|
def base_roles():
|
||||||
return [Role(id=1, name="base_role")]
|
return [Role(id=base_role_id, name="base_role")]
|
||||||
|
|
||||||
@registry.register(contexts=[Context.TESTING])
|
@registry.register(contexts=[Context.TESTING])
|
||||||
def test_roles():
|
def test_roles():
|
||||||
return [Role(id=100, name="test_role")]
|
return [Role(id=test_role_id, name="test_role")]
|
||||||
|
|
||||||
await load_fixtures_by_context(
|
await load_fixtures_by_context(
|
||||||
db_session, registry, Context.BASE, Context.TESTING
|
db_session, registry, Context.BASE, Context.TESTING
|
||||||
@@ -448,14 +655,23 @@ class TestLoadFixturesByContext:
|
|||||||
async def test_load_context_with_dependencies(self, db_session: AsyncSession):
|
async def test_load_context_with_dependencies(self, db_session: AsyncSession):
|
||||||
"""Load context fixtures with cross-context dependencies."""
|
"""Load context fixtures with cross-context dependencies."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register(contexts=[Context.BASE])
|
@registry.register(contexts=[Context.BASE])
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=role_id, name="admin")]
|
||||||
|
|
||||||
@registry.register(depends_on=["roles"], contexts=[Context.TESTING])
|
@registry.register(depends_on=["roles"], contexts=[Context.TESTING])
|
||||||
def test_users():
|
def test_users():
|
||||||
return [User(id=1, username="tester", email="test@test.com", role_id=1)]
|
return [
|
||||||
|
User(
|
||||||
|
id=user_id,
|
||||||
|
username="tester",
|
||||||
|
email="test@test.com",
|
||||||
|
role_id=role_id,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
await load_fixtures_by_context(db_session, registry, Context.TESTING)
|
await load_fixtures_by_context(db_session, registry, Context.TESTING)
|
||||||
|
|
||||||
@@ -471,20 +687,41 @@ class TestGetObjByAttr:
|
|||||||
def setup_method(self):
|
def setup_method(self):
|
||||||
"""Set up test fixtures for each test."""
|
"""Set up test fixtures for each test."""
|
||||||
self.registry = FixtureRegistry()
|
self.registry = FixtureRegistry()
|
||||||
|
self.role_id_1 = uuid.uuid4()
|
||||||
|
self.role_id_2 = uuid.uuid4()
|
||||||
|
self.role_id_3 = uuid.uuid4()
|
||||||
|
self.user_id_1 = uuid.uuid4()
|
||||||
|
self.user_id_2 = uuid.uuid4()
|
||||||
|
|
||||||
|
role_id_1 = self.role_id_1
|
||||||
|
role_id_2 = self.role_id_2
|
||||||
|
role_id_3 = self.role_id_3
|
||||||
|
user_id_1 = self.user_id_1
|
||||||
|
user_id_2 = self.user_id_2
|
||||||
|
|
||||||
@self.registry.register
|
@self.registry.register
|
||||||
def roles() -> list[Role]:
|
def roles() -> list[Role]:
|
||||||
return [
|
return [
|
||||||
Role(id=1, name="admin"),
|
Role(id=role_id_1, name="admin"),
|
||||||
Role(id=2, name="user"),
|
Role(id=role_id_2, name="user"),
|
||||||
Role(id=3, name="moderator"),
|
Role(id=role_id_3, name="moderator"),
|
||||||
]
|
]
|
||||||
|
|
||||||
@self.registry.register(depends_on=["roles"])
|
@self.registry.register(depends_on=["roles"])
|
||||||
def users() -> list[User]:
|
def users() -> list[User]:
|
||||||
return [
|
return [
|
||||||
User(id=1, username="alice", email="alice@example.com", role_id=1),
|
User(
|
||||||
User(id=2, username="bob", email="bob@example.com", role_id=1),
|
id=user_id_1,
|
||||||
|
username="alice",
|
||||||
|
email="alice@example.com",
|
||||||
|
role_id=role_id_1,
|
||||||
|
),
|
||||||
|
User(
|
||||||
|
id=user_id_2,
|
||||||
|
username="bob",
|
||||||
|
email="bob@example.com",
|
||||||
|
role_id=role_id_1,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
self.roles = roles
|
self.roles = roles
|
||||||
@@ -492,26 +729,29 @@ class TestGetObjByAttr:
|
|||||||
|
|
||||||
def test_get_by_id(self):
|
def test_get_by_id(self):
|
||||||
"""Get an object by its id attribute."""
|
"""Get an object by its id attribute."""
|
||||||
role = get_obj_by_attr(self.roles, "id", 1)
|
role = get_obj_by_attr(self.roles, "id", self.role_id_1)
|
||||||
assert role.name == "admin"
|
assert role.name == "admin"
|
||||||
|
|
||||||
def test_get_user_by_username(self):
|
def test_get_user_by_username(self):
|
||||||
"""Get a user by username."""
|
"""Get a user by username."""
|
||||||
user = get_obj_by_attr(self.users, "username", "bob")
|
user = get_obj_by_attr(self.users, "username", "bob")
|
||||||
assert user.id == 2
|
assert user.id == self.user_id_2
|
||||||
assert user.email == "bob@example.com"
|
assert user.email == "bob@example.com"
|
||||||
|
|
||||||
def test_returns_first_match(self):
|
def test_returns_first_match(self):
|
||||||
"""Returns the first matching object when multiple could match."""
|
"""Returns the first matching object when multiple could match."""
|
||||||
user = get_obj_by_attr(self.users, "role_id", 1)
|
user = get_obj_by_attr(self.users, "role_id", self.role_id_1)
|
||||||
assert user.username == "alice"
|
assert user.username == "alice"
|
||||||
|
|
||||||
def test_no_match_raises_stop_iteration(self):
|
def test_no_match_raises_stop_iteration(self):
|
||||||
"""Raises StopIteration when no object matches."""
|
"""Raises StopIteration with contextual message when no object matches."""
|
||||||
with pytest.raises(StopIteration):
|
with pytest.raises(
|
||||||
|
StopIteration,
|
||||||
|
match="No object with name=nonexistent found in fixture 'roles'",
|
||||||
|
):
|
||||||
get_obj_by_attr(self.roles, "name", "nonexistent")
|
get_obj_by_attr(self.roles, "name", "nonexistent")
|
||||||
|
|
||||||
def test_no_match_on_wrong_value_type(self):
|
def test_no_match_on_wrong_value_type(self):
|
||||||
"""Raises StopIteration when value type doesn't match."""
|
"""Raises StopIteration when value type doesn't match."""
|
||||||
with pytest.raises(StopIteration):
|
with pytest.raises(StopIteration):
|
||||||
get_obj_by_attr(self.roles, "id", "1")
|
get_obj_by_attr(self.roles, "id", "not-a-uuid")
|
||||||
|
|||||||
229
tests/test_imports.py
Normal file
229
tests/test_imports.py
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
"""Tests for optional dependency import guards."""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from fastapi_toolsets._imports import require_extra
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequireExtra:
|
||||||
|
"""Tests for the require_extra helper."""
|
||||||
|
|
||||||
|
def test_raises_import_error(self):
|
||||||
|
"""require_extra raises ImportError."""
|
||||||
|
with pytest.raises(ImportError):
|
||||||
|
require_extra(package="some_pkg", extra="some_extra")
|
||||||
|
|
||||||
|
def test_error_message_contains_package_name(self):
|
||||||
|
"""Error message mentions the missing package."""
|
||||||
|
with pytest.raises(ImportError, match="'prometheus_client'"):
|
||||||
|
require_extra(package="prometheus_client", extra="metrics")
|
||||||
|
|
||||||
|
def test_error_message_contains_install_instruction(self):
|
||||||
|
"""Error message contains the pip install command."""
|
||||||
|
with pytest.raises(
|
||||||
|
ImportError, match=r"pip install fastapi-toolsets\[metrics\]"
|
||||||
|
):
|
||||||
|
require_extra(package="prometheus_client", extra="metrics")
|
||||||
|
|
||||||
|
|
||||||
|
def _reload_without_package(module_path: str, blocked_packages: list[str]):
|
||||||
|
"""Reload a module while blocking specific package imports.
|
||||||
|
|
||||||
|
Removes the target module and its parents from sys.modules so they
|
||||||
|
get re-imported, and patches builtins.__import__ to raise ImportError
|
||||||
|
for *blocked_packages*.
|
||||||
|
"""
|
||||||
|
# Remove cached modules so they get re-imported
|
||||||
|
to_remove = [
|
||||||
|
key
|
||||||
|
for key in sys.modules
|
||||||
|
if key == module_path or key.startswith(module_path + ".")
|
||||||
|
]
|
||||||
|
saved = {}
|
||||||
|
for key in to_remove:
|
||||||
|
saved[key] = sys.modules.pop(key)
|
||||||
|
|
||||||
|
# Also remove parent package to force re-execution of __init__.py
|
||||||
|
parts = module_path.rsplit(".", 1)
|
||||||
|
if len(parts) == 2:
|
||||||
|
parent = parts[0]
|
||||||
|
parent_keys = [
|
||||||
|
key for key in sys.modules if key == parent or key.startswith(parent + ".")
|
||||||
|
]
|
||||||
|
for key in parent_keys:
|
||||||
|
if key not in saved:
|
||||||
|
saved[key] = sys.modules.pop(key)
|
||||||
|
|
||||||
|
original_import = (
|
||||||
|
__builtins__.__import__ if hasattr(__builtins__, "__import__") else __import__
|
||||||
|
)
|
||||||
|
|
||||||
|
def blocking_import(name, *args, **kwargs):
|
||||||
|
for blocked in blocked_packages:
|
||||||
|
if name == blocked or name.startswith(blocked + "."):
|
||||||
|
raise ImportError(f"Mocked: No module named '{name}'")
|
||||||
|
return original_import(name, *args, **kwargs)
|
||||||
|
|
||||||
|
return saved, blocking_import
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricsImportGuard:
|
||||||
|
"""Tests for metrics module import guard when prometheus_client is missing."""
|
||||||
|
|
||||||
|
def test_registry_imports_without_prometheus(self):
|
||||||
|
"""Metric and MetricsRegistry are importable without prometheus_client."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.metrics", ["prometheus_client"]
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
mod = importlib.import_module("fastapi_toolsets.metrics")
|
||||||
|
# Registry types should be available (they're stdlib-only)
|
||||||
|
assert hasattr(mod, "Metric")
|
||||||
|
assert hasattr(mod, "MetricsRegistry")
|
||||||
|
finally:
|
||||||
|
# Restore original modules
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.metrics"):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_init_metrics_stub_raises_without_prometheus(self):
|
||||||
|
"""init_metrics raises ImportError when prometheus_client is missing."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.metrics", ["prometheus_client"]
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
mod = importlib.import_module("fastapi_toolsets.metrics")
|
||||||
|
with pytest.raises(ImportError, match="prometheus_client"):
|
||||||
|
mod.init_metrics(None, None) # type: ignore[arg-type]
|
||||||
|
finally:
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.metrics"):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_init_metrics_works_with_prometheus(self):
|
||||||
|
"""init_metrics is the real function when prometheus_client is available."""
|
||||||
|
from fastapi_toolsets.metrics import init_metrics
|
||||||
|
|
||||||
|
# Should be the real function, not a stub
|
||||||
|
assert init_metrics.__module__ == "fastapi_toolsets.metrics.handler"
|
||||||
|
|
||||||
|
|
||||||
|
class TestPytestImportGuard:
|
||||||
|
"""Tests for pytest module import guard when dependencies are missing."""
|
||||||
|
|
||||||
|
def test_import_raises_without_pytest_package(self):
|
||||||
|
"""Importing fastapi_toolsets.pytest raises when pytest is missing."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.pytest", ["pytest"]
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
with pytest.raises(ImportError, match="pytest"):
|
||||||
|
importlib.import_module("fastapi_toolsets.pytest")
|
||||||
|
finally:
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.pytest"):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_import_raises_without_httpx(self):
|
||||||
|
"""Importing fastapi_toolsets.pytest raises when httpx is missing."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.pytest", ["httpx"]
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
with pytest.raises(ImportError, match="httpx"):
|
||||||
|
importlib.import_module("fastapi_toolsets.pytest")
|
||||||
|
finally:
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.pytest"):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_all_exports_available_with_deps(self):
|
||||||
|
"""All expected exports are available when deps are installed."""
|
||||||
|
from fastapi_toolsets.pytest import (
|
||||||
|
cleanup_tables,
|
||||||
|
create_async_client,
|
||||||
|
create_db_session,
|
||||||
|
create_worker_database,
|
||||||
|
register_fixtures,
|
||||||
|
worker_database_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert callable(register_fixtures)
|
||||||
|
assert callable(create_async_client)
|
||||||
|
assert callable(create_db_session)
|
||||||
|
assert callable(create_worker_database)
|
||||||
|
assert callable(worker_database_url)
|
||||||
|
assert callable(cleanup_tables)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCliImportGuard:
|
||||||
|
"""Tests for CLI module import guard when typer is missing."""
|
||||||
|
|
||||||
|
def test_import_raises_without_typer(self):
|
||||||
|
"""Importing cli.app raises when typer is missing."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.cli.app", ["typer"]
|
||||||
|
)
|
||||||
|
# Also remove cli.config since it imports typer too
|
||||||
|
config_keys = [
|
||||||
|
k for k in sys.modules if k.startswith("fastapi_toolsets.cli.config")
|
||||||
|
]
|
||||||
|
for key in config_keys:
|
||||||
|
if key not in saved:
|
||||||
|
saved[key] = sys.modules.pop(key)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
with pytest.raises(ImportError, match="typer"):
|
||||||
|
importlib.import_module("fastapi_toolsets.cli.app")
|
||||||
|
finally:
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.cli.app") or key.startswith(
|
||||||
|
"fastapi_toolsets.cli.config"
|
||||||
|
):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_error_message_suggests_cli_extra(self):
|
||||||
|
"""Error message suggests installing the cli extra."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.cli.app", ["typer"]
|
||||||
|
)
|
||||||
|
config_keys = [
|
||||||
|
k for k in sys.modules if k.startswith("fastapi_toolsets.cli.config")
|
||||||
|
]
|
||||||
|
for key in config_keys:
|
||||||
|
if key not in saved:
|
||||||
|
saved[key] = sys.modules.pop(key)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
with pytest.raises(
|
||||||
|
ImportError, match=r"pip install fastapi-toolsets\[cli\]"
|
||||||
|
):
|
||||||
|
importlib.import_module("fastapi_toolsets.cli.app")
|
||||||
|
finally:
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.cli.app") or key.startswith(
|
||||||
|
"fastapi_toolsets.cli.config"
|
||||||
|
):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_async_command_imports_without_typer(self):
|
||||||
|
"""async_command is importable without typer (stdlib only)."""
|
||||||
|
from fastapi_toolsets.cli import async_command
|
||||||
|
|
||||||
|
assert callable(async_command)
|
||||||
118
tests/test_logger.py
Normal file
118
tests/test_logger.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from fastapi_toolsets.logger import (
|
||||||
|
DEFAULT_FORMAT,
|
||||||
|
UVICORN_LOGGERS,
|
||||||
|
configure_logging,
|
||||||
|
get_logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _reset_loggers():
|
||||||
|
"""Reset the root and uvicorn loggers after each test."""
|
||||||
|
yield
|
||||||
|
root = logging.getLogger()
|
||||||
|
root.handlers.clear()
|
||||||
|
root.setLevel(logging.WARNING)
|
||||||
|
for name in UVICORN_LOGGERS:
|
||||||
|
uv = logging.getLogger(name)
|
||||||
|
uv.handlers.clear()
|
||||||
|
uv.setLevel(logging.NOTSET)
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfigureLogging:
|
||||||
|
def test_sets_up_handler_and_format(self):
|
||||||
|
logger = configure_logging()
|
||||||
|
|
||||||
|
assert len(logger.handlers) == 1
|
||||||
|
handler = logger.handlers[0]
|
||||||
|
assert isinstance(handler, logging.StreamHandler)
|
||||||
|
assert handler.stream is sys.stdout
|
||||||
|
assert handler.formatter is not None
|
||||||
|
assert handler.formatter._fmt == DEFAULT_FORMAT
|
||||||
|
|
||||||
|
def test_default_level_is_info(self):
|
||||||
|
logger = configure_logging()
|
||||||
|
|
||||||
|
assert logger.level == logging.INFO
|
||||||
|
|
||||||
|
def test_custom_level_string(self):
|
||||||
|
logger = configure_logging(level="DEBUG")
|
||||||
|
|
||||||
|
assert logger.level == logging.DEBUG
|
||||||
|
|
||||||
|
def test_custom_level_int(self):
|
||||||
|
logger = configure_logging(level=logging.WARNING)
|
||||||
|
|
||||||
|
assert logger.level == logging.WARNING
|
||||||
|
|
||||||
|
def test_custom_format(self):
|
||||||
|
custom_fmt = "%(levelname)s: %(message)s"
|
||||||
|
logger = configure_logging(fmt=custom_fmt)
|
||||||
|
|
||||||
|
handler = logger.handlers[0]
|
||||||
|
assert handler.formatter is not None
|
||||||
|
assert handler.formatter._fmt == custom_fmt
|
||||||
|
|
||||||
|
def test_named_logger(self):
|
||||||
|
logger = configure_logging(logger_name="myapp")
|
||||||
|
|
||||||
|
assert logger.name == "myapp"
|
||||||
|
assert len(logger.handlers) == 1
|
||||||
|
|
||||||
|
def test_default_configures_root_logger(self):
|
||||||
|
logger = configure_logging()
|
||||||
|
|
||||||
|
assert logger is logging.getLogger()
|
||||||
|
|
||||||
|
def test_idempotent_no_duplicate_handlers(self):
|
||||||
|
configure_logging()
|
||||||
|
configure_logging()
|
||||||
|
logger = configure_logging()
|
||||||
|
|
||||||
|
assert len(logger.handlers) == 1
|
||||||
|
|
||||||
|
def test_configures_uvicorn_loggers(self):
|
||||||
|
configure_logging(level="DEBUG")
|
||||||
|
|
||||||
|
for name in UVICORN_LOGGERS:
|
||||||
|
uv_logger = logging.getLogger(name)
|
||||||
|
assert len(uv_logger.handlers) == 1
|
||||||
|
assert uv_logger.level == logging.DEBUG
|
||||||
|
handler = uv_logger.handlers[0]
|
||||||
|
assert handler.formatter is not None
|
||||||
|
assert handler.formatter._fmt == DEFAULT_FORMAT
|
||||||
|
|
||||||
|
def test_returns_configured_logger(self):
|
||||||
|
logger = configure_logging(logger_name="test.return")
|
||||||
|
|
||||||
|
assert isinstance(logger, logging.Logger)
|
||||||
|
assert logger.name == "test.return"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetLogger:
|
||||||
|
def test_returns_named_logger(self):
|
||||||
|
logger = get_logger("myapp.services")
|
||||||
|
|
||||||
|
assert isinstance(logger, logging.Logger)
|
||||||
|
assert logger.name == "myapp.services"
|
||||||
|
|
||||||
|
def test_returns_root_logger_when_none(self):
|
||||||
|
logger = get_logger(None)
|
||||||
|
|
||||||
|
assert logger is logging.getLogger()
|
||||||
|
|
||||||
|
def test_defaults_to_caller_module_name(self):
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
assert logger.name == __name__
|
||||||
|
|
||||||
|
def test_same_name_returns_same_logger(self):
|
||||||
|
a = get_logger("myapp")
|
||||||
|
b = get_logger("myapp")
|
||||||
|
|
||||||
|
assert a is b
|
||||||
519
tests/test_metrics.py
Normal file
519
tests/test_metrics.py
Normal file
@@ -0,0 +1,519 @@
|
|||||||
|
"""Tests for fastapi_toolsets.metrics module."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from prometheus_client import REGISTRY, CollectorRegistry, Counter, Gauge
|
||||||
|
|
||||||
|
from fastapi_toolsets.metrics import Metric, MetricsRegistry, init_metrics
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _clean_prometheus_registry():
|
||||||
|
"""Unregister test collectors from the global registry after each test."""
|
||||||
|
yield
|
||||||
|
collectors = list(REGISTRY._names_to_collectors.values())
|
||||||
|
for collector in collectors:
|
||||||
|
try:
|
||||||
|
REGISTRY.unregister(collector)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetric:
|
||||||
|
"""Tests for Metric dataclass."""
|
||||||
|
|
||||||
|
def test_default_collect_is_false(self):
|
||||||
|
"""Default collect is False (provider mode)."""
|
||||||
|
definition = Metric(name="test", func=lambda: None)
|
||||||
|
assert definition.collect is False
|
||||||
|
|
||||||
|
def test_collect_true(self):
|
||||||
|
"""Collect can be set to True (collector mode)."""
|
||||||
|
definition = Metric(name="test", func=lambda: None, collect=True)
|
||||||
|
assert definition.collect is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricsRegistry:
|
||||||
|
"""Tests for MetricsRegistry class."""
|
||||||
|
|
||||||
|
def test_register_with_decorator(self):
|
||||||
|
"""Register metric with bare decorator."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def my_counter():
|
||||||
|
return Counter("test_counter", "A test counter")
|
||||||
|
|
||||||
|
names = [m.name for m in registry.get_all()]
|
||||||
|
assert "my_counter" in names
|
||||||
|
|
||||||
|
def test_register_with_custom_name(self):
|
||||||
|
"""Register metric with custom name."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register(name="custom_name")
|
||||||
|
def my_counter():
|
||||||
|
return Counter("test_counter_2", "A test counter")
|
||||||
|
|
||||||
|
definition = registry.get_all()[0]
|
||||||
|
assert definition.name == "custom_name"
|
||||||
|
|
||||||
|
def test_register_as_collector(self):
|
||||||
|
"""Register metric with collect=True."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def collect_something():
|
||||||
|
pass
|
||||||
|
|
||||||
|
definition = registry.get_all()[0]
|
||||||
|
assert definition.collect is True
|
||||||
|
|
||||||
|
def test_register_preserves_function(self):
|
||||||
|
"""Decorator returns the original function unchanged."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
def my_func():
|
||||||
|
return "original"
|
||||||
|
|
||||||
|
result = registry.register(my_func)
|
||||||
|
assert result is my_func
|
||||||
|
assert result() == "original"
|
||||||
|
|
||||||
|
def test_register_parameterized_preserves_function(self):
|
||||||
|
"""Parameterized decorator returns the original function unchanged."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
def my_func():
|
||||||
|
return "original"
|
||||||
|
|
||||||
|
result = registry.register(name="custom")(my_func)
|
||||||
|
assert result is my_func
|
||||||
|
assert result() == "original"
|
||||||
|
|
||||||
|
def test_get_all(self):
|
||||||
|
"""Get all registered metrics."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def metric_a():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def metric_b():
|
||||||
|
pass
|
||||||
|
|
||||||
|
names = {m.name for m in registry.get_all()}
|
||||||
|
assert names == {"metric_a", "metric_b"}
|
||||||
|
|
||||||
|
def test_get_providers(self):
|
||||||
|
"""Get only provider metrics (collect=False)."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def provider():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def collector():
|
||||||
|
pass
|
||||||
|
|
||||||
|
providers = registry.get_providers()
|
||||||
|
assert len(providers) == 1
|
||||||
|
assert providers[0].name == "provider"
|
||||||
|
|
||||||
|
def test_get_collectors(self):
|
||||||
|
"""Get only collector metrics (collect=True)."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def provider():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def collector():
|
||||||
|
pass
|
||||||
|
|
||||||
|
collectors = registry.get_collectors()
|
||||||
|
assert len(collectors) == 1
|
||||||
|
assert collectors[0].name == "collector"
|
||||||
|
|
||||||
|
def test_register_overwrites_same_name(self):
|
||||||
|
"""Registering with the same name overwrites the previous entry."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register(name="metric")
|
||||||
|
def first():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@registry.register(name="metric")
|
||||||
|
def second():
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert len(registry.get_all()) == 1
|
||||||
|
assert registry.get_all()[0].func is second
|
||||||
|
|
||||||
|
|
||||||
|
class TestIncludeRegistry:
|
||||||
|
"""Tests for MetricsRegistry.include_registry method."""
|
||||||
|
|
||||||
|
def test_include_empty_registry(self):
|
||||||
|
"""Include an empty registry does nothing."""
|
||||||
|
main = MetricsRegistry()
|
||||||
|
other = MetricsRegistry()
|
||||||
|
|
||||||
|
@main.register
|
||||||
|
def metric_a():
|
||||||
|
pass
|
||||||
|
|
||||||
|
main.include_registry(other)
|
||||||
|
assert len(main.get_all()) == 1
|
||||||
|
|
||||||
|
def test_include_registry_adds_metrics(self):
|
||||||
|
"""Include registry adds all metrics from the other registry."""
|
||||||
|
main = MetricsRegistry()
|
||||||
|
other = MetricsRegistry()
|
||||||
|
|
||||||
|
@main.register
|
||||||
|
def metric_a():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@other.register
|
||||||
|
def metric_b():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@other.register
|
||||||
|
def metric_c():
|
||||||
|
pass
|
||||||
|
|
||||||
|
main.include_registry(other)
|
||||||
|
names = {m.name for m in main.get_all()}
|
||||||
|
assert names == {"metric_a", "metric_b", "metric_c"}
|
||||||
|
|
||||||
|
def test_include_registry_preserves_collect_flag(self):
|
||||||
|
"""Include registry preserves the collect flag."""
|
||||||
|
main = MetricsRegistry()
|
||||||
|
other = MetricsRegistry()
|
||||||
|
|
||||||
|
@other.register(collect=True)
|
||||||
|
def collector():
|
||||||
|
pass
|
||||||
|
|
||||||
|
main.include_registry(other)
|
||||||
|
assert main.get_all()[0].collect is True
|
||||||
|
|
||||||
|
def test_include_registry_raises_on_duplicate(self):
|
||||||
|
"""Include registry raises ValueError on duplicate metric names."""
|
||||||
|
main = MetricsRegistry()
|
||||||
|
other = MetricsRegistry()
|
||||||
|
|
||||||
|
@main.register(name="metric")
|
||||||
|
def metric_main():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@other.register(name="metric")
|
||||||
|
def metric_other():
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="already exists"):
|
||||||
|
main.include_registry(other)
|
||||||
|
|
||||||
|
def test_include_multiple_registries(self):
|
||||||
|
"""Include multiple registries sequentially."""
|
||||||
|
main = MetricsRegistry()
|
||||||
|
sub1 = MetricsRegistry()
|
||||||
|
sub2 = MetricsRegistry()
|
||||||
|
|
||||||
|
@main.register
|
||||||
|
def base():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@sub1.register
|
||||||
|
def sub1_metric():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@sub2.register
|
||||||
|
def sub2_metric():
|
||||||
|
pass
|
||||||
|
|
||||||
|
main.include_registry(sub1)
|
||||||
|
main.include_registry(sub2)
|
||||||
|
|
||||||
|
names = {m.name for m in main.get_all()}
|
||||||
|
assert names == {"base", "sub1_metric", "sub2_metric"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestInitMetrics:
|
||||||
|
"""Tests for init_metrics function."""
|
||||||
|
|
||||||
|
def test_returns_app(self):
|
||||||
|
"""Returns the FastAPI app."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
result = init_metrics(app, registry)
|
||||||
|
assert result is app
|
||||||
|
|
||||||
|
def test_metrics_endpoint_responds(self):
|
||||||
|
"""The /metrics endpoint returns 200."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/metrics")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_metrics_endpoint_content_type(self):
|
||||||
|
"""The /metrics endpoint returns prometheus content type."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/metrics")
|
||||||
|
|
||||||
|
assert "text/plain" in response.headers["content-type"]
|
||||||
|
|
||||||
|
def test_custom_path(self):
|
||||||
|
"""Custom path is used for the metrics endpoint."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
init_metrics(app, registry, path="/custom-metrics")
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
assert client.get("/custom-metrics").status_code == 200
|
||||||
|
assert client.get("/metrics").status_code == 404
|
||||||
|
|
||||||
|
def test_providers_called_at_init(self):
|
||||||
|
"""Provider functions are called once at init time."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
mock = MagicMock()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def my_provider():
|
||||||
|
mock()
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
mock.assert_called_once()
|
||||||
|
|
||||||
|
def test_collectors_called_on_scrape(self):
|
||||||
|
"""Collector functions are called on each scrape."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
mock = MagicMock()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def my_collector():
|
||||||
|
mock()
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
client.get("/metrics")
|
||||||
|
client.get("/metrics")
|
||||||
|
|
||||||
|
assert mock.call_count == 2
|
||||||
|
|
||||||
|
def test_collectors_not_called_at_init(self):
|
||||||
|
"""Collector functions are not called at init time."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
mock = MagicMock()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def my_collector():
|
||||||
|
mock()
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
mock.assert_not_called()
|
||||||
|
|
||||||
|
def test_async_collectors_called_on_scrape(self):
|
||||||
|
"""Async collector functions are awaited on each scrape."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
mock = AsyncMock()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
async def my_async_collector():
|
||||||
|
await mock()
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
client.get("/metrics")
|
||||||
|
client.get("/metrics")
|
||||||
|
|
||||||
|
assert mock.call_count == 2
|
||||||
|
|
||||||
|
def test_mixed_sync_and_async_collectors(self):
|
||||||
|
"""Both sync and async collectors are called on scrape."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
sync_mock = MagicMock()
|
||||||
|
async_mock = AsyncMock()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def sync_collector():
|
||||||
|
sync_mock()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
async def async_collector():
|
||||||
|
await async_mock()
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
client.get("/metrics")
|
||||||
|
|
||||||
|
sync_mock.assert_called_once()
|
||||||
|
async_mock.assert_called_once()
|
||||||
|
|
||||||
|
def test_registered_metrics_appear_in_output(self):
|
||||||
|
"""Metrics created by providers appear in /metrics output."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def my_gauge():
|
||||||
|
g = Gauge("test_gauge_value", "A test gauge")
|
||||||
|
g.set(42)
|
||||||
|
return g
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/metrics")
|
||||||
|
|
||||||
|
assert b"test_gauge_value" in response.content
|
||||||
|
assert b"42.0" in response.content
|
||||||
|
|
||||||
|
def test_endpoint_not_in_openapi_schema(self):
|
||||||
|
"""The /metrics endpoint is not included in the OpenAPI schema."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
schema = app.openapi()
|
||||||
|
assert "/metrics" not in schema.get("paths", {})
|
||||||
|
|
||||||
|
|
||||||
|
class TestMultiProcessMode:
|
||||||
|
"""Tests for multi-process Prometheus mode."""
|
||||||
|
|
||||||
|
def test_multiprocess_with_env_var(self):
|
||||||
|
"""Multi-process mode works when PROMETHEUS_MULTIPROC_DIR is set."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
os.environ["PROMETHEUS_MULTIPROC_DIR"] = tmpdir
|
||||||
|
try:
|
||||||
|
# Use a separate registry to avoid conflicts with default
|
||||||
|
prom_registry = CollectorRegistry()
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def mp_counter():
|
||||||
|
return Counter(
|
||||||
|
"mp_test_counter",
|
||||||
|
"A multiprocess counter",
|
||||||
|
registry=prom_registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/metrics")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
finally:
|
||||||
|
del os.environ["PROMETHEUS_MULTIPROC_DIR"]
|
||||||
|
|
||||||
|
def test_single_process_without_env_var(self):
|
||||||
|
"""Single-process mode when PROMETHEUS_MULTIPROC_DIR is not set."""
|
||||||
|
os.environ.pop("PROMETHEUS_MULTIPROC_DIR", None)
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def sp_gauge():
|
||||||
|
g = Gauge("sp_test_gauge", "A single-process gauge")
|
||||||
|
g.set(99)
|
||||||
|
return g
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/metrics")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert b"sp_test_gauge" in response.content
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricsIntegration:
|
||||||
|
"""Integration tests for the metrics module."""
|
||||||
|
|
||||||
|
def test_full_workflow(self):
|
||||||
|
"""Full workflow: registry, providers, collectors, endpoint."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
call_count = {"value": 0}
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def request_counter():
|
||||||
|
return Counter(
|
||||||
|
"integration_requests_total",
|
||||||
|
"Total requests",
|
||||||
|
["method"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def collect_uptime():
|
||||||
|
call_count["value"] += 1
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
response = client.get("/metrics")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert b"integration_requests_total" in response.content
|
||||||
|
assert call_count["value"] == 1
|
||||||
|
|
||||||
|
response = client.get("/metrics")
|
||||||
|
assert call_count["value"] == 2
|
||||||
|
|
||||||
|
def test_multiple_registries_merged(self):
|
||||||
|
"""Multiple registries can be merged and used together."""
|
||||||
|
app = FastAPI()
|
||||||
|
main = MetricsRegistry()
|
||||||
|
sub = MetricsRegistry()
|
||||||
|
|
||||||
|
@main.register
|
||||||
|
def main_gauge():
|
||||||
|
g = Gauge("main_gauge_val", "Main gauge")
|
||||||
|
g.set(1)
|
||||||
|
return g
|
||||||
|
|
||||||
|
@sub.register
|
||||||
|
def sub_gauge():
|
||||||
|
g = Gauge("sub_gauge_val", "Sub gauge")
|
||||||
|
g.set(2)
|
||||||
|
return g
|
||||||
|
|
||||||
|
main.include_registry(sub)
|
||||||
|
init_metrics(app, main)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/metrics")
|
||||||
|
|
||||||
|
assert b"main_gauge_val" in response.content
|
||||||
|
assert b"sub_gauge_val" in response.content
|
||||||
@@ -1,44 +1,73 @@
|
|||||||
"""Tests for fastapi_toolsets.pytest module."""
|
"""Tests for fastapi_toolsets.pytest module."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import FastAPI
|
from fastapi import Depends, FastAPI
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select, text
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.engine import make_url
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, selectinload
|
||||||
|
|
||||||
from fastapi_toolsets.fixtures import Context, FixtureRegistry
|
from fastapi_toolsets.fixtures import Context, FixtureRegistry
|
||||||
from fastapi_toolsets.pytest import (
|
from fastapi_toolsets.pytest import (
|
||||||
|
cleanup_tables,
|
||||||
create_async_client,
|
create_async_client,
|
||||||
create_db_session,
|
create_db_session,
|
||||||
|
create_worker_database,
|
||||||
register_fixtures,
|
register_fixtures,
|
||||||
|
worker_database_url,
|
||||||
)
|
)
|
||||||
|
from fastapi_toolsets.pytest.utils import _get_xdist_worker
|
||||||
|
|
||||||
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
|
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
|
||||||
|
|
||||||
test_registry = FixtureRegistry()
|
test_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
# Fixed UUIDs for test fixtures to allow consistent assertions
|
||||||
|
ROLE_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000001000")
|
||||||
|
ROLE_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000001001")
|
||||||
|
USER_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000002000")
|
||||||
|
USER_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000002001")
|
||||||
|
USER_EXTRA_ID = uuid.UUID("00000000-0000-0000-0000-000000002002")
|
||||||
|
|
||||||
|
|
||||||
@test_registry.register(contexts=[Context.BASE])
|
@test_registry.register(contexts=[Context.BASE])
|
||||||
def roles() -> list[Role]:
|
def roles() -> list[Role]:
|
||||||
return [
|
return [
|
||||||
Role(id=1000, name="plugin_admin"),
|
Role(id=ROLE_ADMIN_ID, name="plugin_admin"),
|
||||||
Role(id=1001, name="plugin_user"),
|
Role(id=ROLE_USER_ID, name="plugin_user"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@test_registry.register(depends_on=["roles"], contexts=[Context.BASE])
|
@test_registry.register(depends_on=["roles"], contexts=[Context.BASE])
|
||||||
def users() -> list[User]:
|
def users() -> list[User]:
|
||||||
return [
|
return [
|
||||||
User(id=1000, username="plugin_admin", email="padmin@test.com", role_id=1000),
|
User(
|
||||||
User(id=1001, username="plugin_user", email="puser@test.com", role_id=1001),
|
id=USER_ADMIN_ID,
|
||||||
|
username="plugin_admin",
|
||||||
|
email="padmin@test.com",
|
||||||
|
role_id=ROLE_ADMIN_ID,
|
||||||
|
),
|
||||||
|
User(
|
||||||
|
id=USER_USER_ID,
|
||||||
|
username="plugin_user",
|
||||||
|
email="puser@test.com",
|
||||||
|
role_id=ROLE_USER_ID,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@test_registry.register(depends_on=["users"], contexts=[Context.TESTING])
|
@test_registry.register(depends_on=["users"], contexts=[Context.TESTING])
|
||||||
def extra_users() -> list[User]:
|
def extra_users() -> list[User]:
|
||||||
return [
|
return [
|
||||||
User(id=1002, username="plugin_extra", email="pextra@test.com", role_id=1001),
|
User(
|
||||||
|
id=USER_EXTRA_ID,
|
||||||
|
username="plugin_extra",
|
||||||
|
email="pextra@test.com",
|
||||||
|
role_id=ROLE_USER_ID,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -73,7 +102,7 @@ class TestGeneratedFixtures:
|
|||||||
assert fixture_roles[1].name == "plugin_user"
|
assert fixture_roles[1].name == "plugin_user"
|
||||||
|
|
||||||
# Verify data is in database
|
# Verify data is in database
|
||||||
count = await RoleCrud.count(db_session, [Role.id >= 1000])
|
count = await RoleCrud.count(db_session)
|
||||||
assert count == 2
|
assert count == 2
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -86,11 +115,11 @@ class TestGeneratedFixtures:
|
|||||||
assert len(fixture_users) == 2
|
assert len(fixture_users) == 2
|
||||||
|
|
||||||
# Roles should also be in database
|
# Roles should also be in database
|
||||||
roles_count = await RoleCrud.count(db_session, [Role.id >= 1000])
|
roles_count = await RoleCrud.count(db_session)
|
||||||
assert roles_count == 2
|
assert roles_count == 2
|
||||||
|
|
||||||
# Users should be in database
|
# Users should be in database
|
||||||
users_count = await UserCrud.count(db_session, [User.id >= 1000])
|
users_count = await UserCrud.count(db_session)
|
||||||
assert users_count == 2
|
assert users_count == 2
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -100,7 +129,7 @@ class TestGeneratedFixtures:
|
|||||||
"""Fixture returns actual model instances."""
|
"""Fixture returns actual model instances."""
|
||||||
user = fixture_users[0]
|
user = fixture_users[0]
|
||||||
assert isinstance(user, User)
|
assert isinstance(user, User)
|
||||||
assert user.id == 1000
|
assert user.id == USER_ADMIN_ID
|
||||||
assert user.username == "plugin_admin"
|
assert user.username == "plugin_admin"
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -111,7 +140,7 @@ class TestGeneratedFixtures:
|
|||||||
# Load user with role relationship
|
# Load user with role relationship
|
||||||
user = await UserCrud.get(
|
user = await UserCrud.get(
|
||||||
db_session,
|
db_session,
|
||||||
[User.id == 1000],
|
[User.id == USER_ADMIN_ID],
|
||||||
load_options=[selectinload(User.role)],
|
load_options=[selectinload(User.role)],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -127,8 +156,8 @@ class TestGeneratedFixtures:
|
|||||||
assert len(fixture_extra_users) == 1
|
assert len(fixture_extra_users) == 1
|
||||||
|
|
||||||
# All fixtures should be loaded
|
# All fixtures should be loaded
|
||||||
roles_count = await RoleCrud.count(db_session, [Role.id >= 1000])
|
roles_count = await RoleCrud.count(db_session)
|
||||||
users_count = await UserCrud.count(db_session, [User.id >= 1000])
|
users_count = await UserCrud.count(db_session)
|
||||||
|
|
||||||
assert roles_count == 2
|
assert roles_count == 2
|
||||||
assert users_count == 3 # 2 from users + 1 from extra_users
|
assert users_count == 3 # 2 from users + 1 from extra_users
|
||||||
@@ -141,8 +170,7 @@ class TestGeneratedFixtures:
|
|||||||
# Get all users loaded by fixture
|
# Get all users loaded by fixture
|
||||||
users = await UserCrud.get_multi(
|
users = await UserCrud.get_multi(
|
||||||
db_session,
|
db_session,
|
||||||
filters=[User.id >= 1000],
|
order_by=User.username,
|
||||||
order_by=User.id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(users) == 2
|
assert len(users) == 2
|
||||||
@@ -161,8 +189,8 @@ class TestGeneratedFixtures:
|
|||||||
assert len(fixture_users) == 2
|
assert len(fixture_users) == 2
|
||||||
|
|
||||||
# Both should be in database
|
# Both should be in database
|
||||||
roles = await RoleCrud.get_multi(db_session, filters=[Role.id >= 1000])
|
roles = await RoleCrud.get_multi(db_session)
|
||||||
users = await UserCrud.get_multi(db_session, filters=[User.id >= 1000])
|
users = await UserCrud.get_multi(db_session)
|
||||||
|
|
||||||
assert len(roles) == 2
|
assert len(roles) == 2
|
||||||
assert len(users) == 2
|
assert len(users) == 2
|
||||||
@@ -208,6 +236,30 @@ class TestCreateAsyncClient:
|
|||||||
|
|
||||||
assert client_ref.is_closed
|
assert client_ref.is_closed
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_dependency_overrides_applied_and_cleaned(self):
|
||||||
|
"""Dependency overrides are applied during the context and removed after."""
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
async def original_dep() -> str:
|
||||||
|
return "original"
|
||||||
|
|
||||||
|
async def override_dep() -> str:
|
||||||
|
return "overridden"
|
||||||
|
|
||||||
|
@app.get("/dep")
|
||||||
|
async def dep_endpoint(value: str = Depends(original_dep)):
|
||||||
|
return {"value": value}
|
||||||
|
|
||||||
|
async with create_async_client(
|
||||||
|
app, dependency_overrides={original_dep: override_dep}
|
||||||
|
) as client:
|
||||||
|
response = await client.get("/dep")
|
||||||
|
assert response.json() == {"value": "overridden"}
|
||||||
|
|
||||||
|
# Overrides should be cleaned up
|
||||||
|
assert original_dep not in app.dependency_overrides
|
||||||
|
|
||||||
|
|
||||||
class TestCreateDbSession:
|
class TestCreateDbSession:
|
||||||
"""Tests for create_db_session helper."""
|
"""Tests for create_db_session helper."""
|
||||||
@@ -215,14 +267,15 @@ class TestCreateDbSession:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_creates_working_session(self):
|
async def test_creates_working_session(self):
|
||||||
"""Session can perform database operations."""
|
"""Session can perform database operations."""
|
||||||
|
role_id = uuid.uuid4()
|
||||||
async with create_db_session(DATABASE_URL, Base) as session:
|
async with create_db_session(DATABASE_URL, Base) as session:
|
||||||
assert isinstance(session, AsyncSession)
|
assert isinstance(session, AsyncSession)
|
||||||
|
|
||||||
role = Role(id=9001, name="test_helper_role")
|
role = Role(id=role_id, name="test_helper_role")
|
||||||
session.add(role)
|
session.add(role)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
result = await session.execute(select(Role).where(Role.id == 9001))
|
result = await session.execute(select(Role).where(Role.id == role_id))
|
||||||
fetched = result.scalar_one()
|
fetched = result.scalar_one()
|
||||||
assert fetched.name == "test_helper_role"
|
assert fetched.name == "test_helper_role"
|
||||||
|
|
||||||
@@ -237,8 +290,9 @@ class TestCreateDbSession:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_tables_dropped_after_session(self):
|
async def test_tables_dropped_after_session(self):
|
||||||
"""Tables are dropped after session closes when drop_tables=True."""
|
"""Tables are dropped after session closes when drop_tables=True."""
|
||||||
|
role_id = uuid.uuid4()
|
||||||
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||||
role = Role(id=9002, name="will_be_dropped")
|
role = Role(id=role_id, name="will_be_dropped")
|
||||||
session.add(role)
|
session.add(role)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
@@ -250,14 +304,15 @@ class TestCreateDbSession:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_tables_preserved_when_drop_disabled(self):
|
async def test_tables_preserved_when_drop_disabled(self):
|
||||||
"""Tables are preserved when drop_tables=False."""
|
"""Tables are preserved when drop_tables=False."""
|
||||||
|
role_id = uuid.uuid4()
|
||||||
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
|
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
|
||||||
role = Role(id=9003, name="preserved_role")
|
role = Role(id=role_id, name="preserved_role")
|
||||||
session.add(role)
|
session.add(role)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
# Create another session without dropping
|
# Create another session without dropping
|
||||||
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
|
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
|
||||||
result = await session.execute(select(Role).where(Role.id == 9003))
|
result = await session.execute(select(Role).where(Role.id == role_id))
|
||||||
fetched = result.scalar_one_or_none()
|
fetched = result.scalar_one_or_none()
|
||||||
assert fetched is not None
|
assert fetched is not None
|
||||||
assert fetched.name == "preserved_role"
|
assert fetched.name == "preserved_role"
|
||||||
@@ -265,3 +320,216 @@ class TestCreateDbSession:
|
|||||||
# Cleanup: drop tables manually
|
# Cleanup: drop tables manually
|
||||||
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _:
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_cleanup_truncates_tables(self):
|
||||||
|
"""Tables are truncated after session closes when cleanup=True."""
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
async with create_db_session(
|
||||||
|
DATABASE_URL, Base, cleanup=True, drop_tables=False
|
||||||
|
) as session:
|
||||||
|
role = Role(id=role_id, name="will_be_cleaned")
|
||||||
|
session.add(role)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Data should have been truncated, but tables still exist
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||||
|
result = await session.execute(select(Role))
|
||||||
|
assert result.all() == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetXdistWorker:
|
||||||
|
"""Tests for _get_xdist_worker helper."""
|
||||||
|
|
||||||
|
def test_returns_default_test_db_without_env_var(
|
||||||
|
self, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
|
"""Returns default_test_db when PYTEST_XDIST_WORKER is not set."""
|
||||||
|
monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
|
||||||
|
assert _get_xdist_worker("my_default") == "my_default"
|
||||||
|
|
||||||
|
def test_returns_worker_name(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""Returns the worker name from the environment variable."""
|
||||||
|
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw0")
|
||||||
|
assert _get_xdist_worker("ignored") == "gw0"
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkerDatabaseUrl:
|
||||||
|
"""Tests for worker_database_url helper."""
|
||||||
|
|
||||||
|
def test_appends_default_test_db_without_xdist(
|
||||||
|
self, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
|
"""default_test_db is appended when not running under xdist."""
|
||||||
|
monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
|
||||||
|
url = "postgresql+asyncpg://user:pass@localhost:5432/mydb"
|
||||||
|
result = worker_database_url(url, default_test_db="fallback")
|
||||||
|
assert make_url(result).database == "mydb_fallback"
|
||||||
|
|
||||||
|
def test_appends_worker_id_to_database_name(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""Worker name is appended to the database name."""
|
||||||
|
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw0")
|
||||||
|
url = "postgresql+asyncpg://user:pass@localhost:5432/db"
|
||||||
|
result = worker_database_url(url, default_test_db="unused")
|
||||||
|
assert make_url(result).database == "db_gw0"
|
||||||
|
|
||||||
|
def test_preserves_url_components(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""Host, port, username, password, and driver are preserved."""
|
||||||
|
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw2")
|
||||||
|
url = "postgresql+asyncpg://myuser:secret@dbhost:6543/testdb"
|
||||||
|
result = make_url(worker_database_url(url, default_test_db="unused"))
|
||||||
|
|
||||||
|
assert result.drivername == "postgresql+asyncpg"
|
||||||
|
assert result.username == "myuser"
|
||||||
|
assert result.password == "secret"
|
||||||
|
assert result.host == "dbhost"
|
||||||
|
assert result.port == 6543
|
||||||
|
assert result.database == "testdb_gw2"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateWorkerDatabase:
|
||||||
|
"""Tests for create_worker_database context manager."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_creates_default_db_without_xdist(
|
||||||
|
self, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
|
"""Without xdist, creates a database suffixed with default_test_db."""
|
||||||
|
monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
|
||||||
|
default_test_db = "no_xdist_default"
|
||||||
|
expected_db = make_url(
|
||||||
|
worker_database_url(DATABASE_URL, default_test_db=default_test_db)
|
||||||
|
).database
|
||||||
|
|
||||||
|
async with create_worker_database(
|
||||||
|
DATABASE_URL, default_test_db=default_test_db
|
||||||
|
) as url:
|
||||||
|
assert make_url(url).database == expected_db
|
||||||
|
|
||||||
|
# Verify the database exists while inside the context
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||||
|
{"name": expected_db},
|
||||||
|
)
|
||||||
|
assert result.scalar() == 1
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
# After context exit the database should be dropped
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||||
|
{"name": expected_db},
|
||||||
|
)
|
||||||
|
assert result.scalar() is None
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_creates_and_drops_worker_database(
|
||||||
|
self, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
|
"""Worker database exists inside the context and is dropped after."""
|
||||||
|
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw_test_create")
|
||||||
|
expected_db = make_url(
|
||||||
|
worker_database_url(DATABASE_URL, default_test_db="unused")
|
||||||
|
).database
|
||||||
|
|
||||||
|
async with create_worker_database(DATABASE_URL) as url:
|
||||||
|
assert make_url(url).database == expected_db
|
||||||
|
|
||||||
|
# Verify the database exists while inside the context
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||||
|
{"name": expected_db},
|
||||||
|
)
|
||||||
|
assert result.scalar() == 1
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
# After context exit the database should be dropped
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||||
|
{"name": expected_db},
|
||||||
|
)
|
||||||
|
assert result.scalar() is None
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_cleans_up_stale_database(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""A pre-existing worker database is dropped and recreated."""
|
||||||
|
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw_test_stale")
|
||||||
|
expected_db = make_url(
|
||||||
|
worker_database_url(DATABASE_URL, default_test_db="unused")
|
||||||
|
).database
|
||||||
|
|
||||||
|
# Pre-create the database to simulate a stale leftover
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
await conn.execute(text(f"DROP DATABASE IF EXISTS {expected_db}"))
|
||||||
|
await conn.execute(text(f"CREATE DATABASE {expected_db}"))
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
# Should succeed despite the database already existing
|
||||||
|
async with create_worker_database(DATABASE_URL) as url:
|
||||||
|
assert make_url(url).database == expected_db
|
||||||
|
|
||||||
|
# Verify cleanup after context exit
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||||
|
{"name": expected_db},
|
||||||
|
)
|
||||||
|
assert result.scalar() is None
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
class TestCleanupTables:
|
||||||
|
"""Tests for cleanup_tables helper."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_truncates_all_tables(self):
|
||||||
|
"""All table rows are removed after cleanup_tables."""
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||||
|
role = Role(id=uuid.uuid4(), name="cleanup_role")
|
||||||
|
session.add(role)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
username="cleanup_user",
|
||||||
|
email="cleanup@test.com",
|
||||||
|
role_id=role.id,
|
||||||
|
)
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Verify rows exist
|
||||||
|
roles_count = await RoleCrud.count(session)
|
||||||
|
users_count = await UserCrud.count(session)
|
||||||
|
assert roles_count == 1
|
||||||
|
assert users_count == 1
|
||||||
|
|
||||||
|
await cleanup_tables(session, Base)
|
||||||
|
|
||||||
|
# Verify tables are empty
|
||||||
|
roles_count = await RoleCrud.count(session)
|
||||||
|
users_count = await UserCrud.count(session)
|
||||||
|
assert roles_count == 0
|
||||||
|
assert users_count == 0
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_noop_for_empty_metadata(self):
|
||||||
|
"""cleanup_tables does not raise when metadata has no tables."""
|
||||||
|
|
||||||
|
class EmptyBase(DeclarativeBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||||
|
# Should not raise
|
||||||
|
await cleanup_tables(session, EmptyBase)
|
||||||
|
|||||||
@@ -5,7 +5,9 @@ from pydantic import ValidationError
|
|||||||
|
|
||||||
from fastapi_toolsets.schemas import (
|
from fastapi_toolsets.schemas import (
|
||||||
ApiError,
|
ApiError,
|
||||||
|
CursorPagination,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
|
OffsetPagination,
|
||||||
PaginatedResponse,
|
PaginatedResponse,
|
||||||
Pagination,
|
Pagination,
|
||||||
Response,
|
Response,
|
||||||
@@ -46,6 +48,31 @@ class TestApiError:
|
|||||||
assert error.desc == "The resource was not found."
|
assert error.desc == "The resource was not found."
|
||||||
assert error.err_code == "RES-404"
|
assert error.err_code == "RES-404"
|
||||||
|
|
||||||
|
def test_data_defaults_to_none(self):
|
||||||
|
"""ApiError data field defaults to None."""
|
||||||
|
error = ApiError(
|
||||||
|
code=404,
|
||||||
|
msg="Not Found",
|
||||||
|
desc="The resource was not found.",
|
||||||
|
err_code="RES-404",
|
||||||
|
)
|
||||||
|
assert error.data is None
|
||||||
|
|
||||||
|
def test_create_with_data(self):
|
||||||
|
"""ApiError can be created with a data payload."""
|
||||||
|
error = ApiError(
|
||||||
|
code=422,
|
||||||
|
msg="Validation Error",
|
||||||
|
desc="2 validation error(s) detected",
|
||||||
|
err_code="VAL-422",
|
||||||
|
data={
|
||||||
|
"errors": [{"field": "name", "message": "required", "type": "missing"}]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert error.data == {
|
||||||
|
"errors": [{"field": "name", "message": "required", "type": "missing"}]
|
||||||
|
}
|
||||||
|
|
||||||
def test_requires_all_fields(self):
|
def test_requires_all_fields(self):
|
||||||
"""ApiError requires all fields."""
|
"""ApiError requires all fields."""
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
@@ -129,12 +156,12 @@ class TestErrorResponse:
|
|||||||
assert data["description"] == "Details"
|
assert data["description"] == "Details"
|
||||||
|
|
||||||
|
|
||||||
class TestPagination:
|
class TestOffsetPagination:
|
||||||
"""Tests for Pagination schema."""
|
"""Tests for OffsetPagination schema (canonical name for offset-based pagination)."""
|
||||||
|
|
||||||
def test_create_pagination(self):
|
def test_create_pagination(self):
|
||||||
"""Create Pagination with all fields."""
|
"""Create OffsetPagination with all fields."""
|
||||||
pagination = Pagination(
|
pagination = OffsetPagination(
|
||||||
total_count=100,
|
total_count=100,
|
||||||
items_per_page=10,
|
items_per_page=10,
|
||||||
page=1,
|
page=1,
|
||||||
@@ -148,7 +175,7 @@ class TestPagination:
|
|||||||
|
|
||||||
def test_last_page_has_more_false(self):
|
def test_last_page_has_more_false(self):
|
||||||
"""Last page has has_more=False."""
|
"""Last page has has_more=False."""
|
||||||
pagination = Pagination(
|
pagination = OffsetPagination(
|
||||||
total_count=25,
|
total_count=25,
|
||||||
items_per_page=10,
|
items_per_page=10,
|
||||||
page=3,
|
page=3,
|
||||||
@@ -158,8 +185,8 @@ class TestPagination:
|
|||||||
assert pagination.has_more is False
|
assert pagination.has_more is False
|
||||||
|
|
||||||
def test_serialization(self):
|
def test_serialization(self):
|
||||||
"""Pagination serializes correctly."""
|
"""OffsetPagination serializes correctly."""
|
||||||
pagination = Pagination(
|
pagination = OffsetPagination(
|
||||||
total_count=50,
|
total_count=50,
|
||||||
items_per_page=20,
|
items_per_page=20,
|
||||||
page=2,
|
page=2,
|
||||||
@@ -172,6 +199,77 @@ class TestPagination:
|
|||||||
assert data["page"] == 2
|
assert data["page"] == 2
|
||||||
assert data["has_more"] is True
|
assert data["has_more"] is True
|
||||||
|
|
||||||
|
def test_pagination_alias_is_offset_pagination(self):
|
||||||
|
"""Pagination is a backward-compatible alias for OffsetPagination."""
|
||||||
|
assert Pagination is OffsetPagination
|
||||||
|
|
||||||
|
def test_pagination_alias_constructs_offset_pagination(self):
|
||||||
|
"""Code using Pagination(...) still works unchanged."""
|
||||||
|
pagination = Pagination(
|
||||||
|
total_count=10,
|
||||||
|
items_per_page=5,
|
||||||
|
page=2,
|
||||||
|
has_more=False,
|
||||||
|
)
|
||||||
|
assert isinstance(pagination, OffsetPagination)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCursorPagination:
|
||||||
|
"""Tests for CursorPagination schema."""
|
||||||
|
|
||||||
|
def test_create_with_next_cursor(self):
|
||||||
|
"""CursorPagination with a next cursor indicates more pages."""
|
||||||
|
pagination = CursorPagination(
|
||||||
|
next_cursor="eyJ2YWx1ZSI6ICIxMjMifQ==",
|
||||||
|
items_per_page=20,
|
||||||
|
has_more=True,
|
||||||
|
)
|
||||||
|
assert pagination.next_cursor == "eyJ2YWx1ZSI6ICIxMjMifQ=="
|
||||||
|
assert pagination.prev_cursor is None
|
||||||
|
assert pagination.items_per_page == 20
|
||||||
|
assert pagination.has_more is True
|
||||||
|
|
||||||
|
def test_create_last_page(self):
|
||||||
|
"""CursorPagination for the last page has next_cursor=None and has_more=False."""
|
||||||
|
pagination = CursorPagination(
|
||||||
|
next_cursor=None,
|
||||||
|
items_per_page=20,
|
||||||
|
has_more=False,
|
||||||
|
)
|
||||||
|
assert pagination.next_cursor is None
|
||||||
|
assert pagination.has_more is False
|
||||||
|
|
||||||
|
def test_prev_cursor_defaults_to_none(self):
|
||||||
|
"""prev_cursor defaults to None."""
|
||||||
|
pagination = CursorPagination(
|
||||||
|
next_cursor=None, items_per_page=10, has_more=False
|
||||||
|
)
|
||||||
|
assert pagination.prev_cursor is None
|
||||||
|
|
||||||
|
def test_prev_cursor_can_be_set(self):
|
||||||
|
"""prev_cursor can be explicitly set."""
|
||||||
|
pagination = CursorPagination(
|
||||||
|
next_cursor="next123",
|
||||||
|
prev_cursor="prev456",
|
||||||
|
items_per_page=10,
|
||||||
|
has_more=True,
|
||||||
|
)
|
||||||
|
assert pagination.prev_cursor == "prev456"
|
||||||
|
|
||||||
|
def test_serialization(self):
|
||||||
|
"""CursorPagination serializes correctly."""
|
||||||
|
pagination = CursorPagination(
|
||||||
|
next_cursor="abc123",
|
||||||
|
prev_cursor="xyz789",
|
||||||
|
items_per_page=20,
|
||||||
|
has_more=True,
|
||||||
|
)
|
||||||
|
data = pagination.model_dump()
|
||||||
|
assert data["next_cursor"] == "abc123"
|
||||||
|
assert data["prev_cursor"] == "xyz789"
|
||||||
|
assert data["items_per_page"] == 20
|
||||||
|
assert data["has_more"] is True
|
||||||
|
|
||||||
|
|
||||||
class TestPaginatedResponse:
|
class TestPaginatedResponse:
|
||||||
"""Tests for PaginatedResponse schema."""
|
"""Tests for PaginatedResponse schema."""
|
||||||
@@ -189,6 +287,7 @@ class TestPaginatedResponse:
|
|||||||
pagination=pagination,
|
pagination=pagination,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert isinstance(response.pagination, OffsetPagination)
|
||||||
assert len(response.data) == 2
|
assert len(response.data) == 2
|
||||||
assert response.pagination.total_count == 30
|
assert response.pagination.total_count == 30
|
||||||
assert response.status == ResponseStatus.SUCCESS
|
assert response.status == ResponseStatus.SUCCESS
|
||||||
@@ -222,6 +321,7 @@ class TestPaginatedResponse:
|
|||||||
pagination=pagination,
|
pagination=pagination,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert isinstance(response.pagination, OffsetPagination)
|
||||||
assert response.data == []
|
assert response.data == []
|
||||||
assert response.pagination.total_count == 0
|
assert response.pagination.total_count == 0
|
||||||
|
|
||||||
@@ -265,6 +365,36 @@ class TestPaginatedResponse:
|
|||||||
assert data["data"] == ["item1", "item2"]
|
assert data["data"] == ["item1", "item2"]
|
||||||
assert data["pagination"]["page"] == 5
|
assert data["pagination"]["page"] == 5
|
||||||
|
|
||||||
|
def test_pagination_field_accepts_offset_pagination(self):
|
||||||
|
"""PaginatedResponse.pagination accepts OffsetPagination."""
|
||||||
|
response = PaginatedResponse(
|
||||||
|
data=[1, 2],
|
||||||
|
pagination=OffsetPagination(
|
||||||
|
total_count=2, items_per_page=10, page=1, has_more=False
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert isinstance(response.pagination, OffsetPagination)
|
||||||
|
|
||||||
|
def test_pagination_field_accepts_cursor_pagination(self):
|
||||||
|
"""PaginatedResponse.pagination accepts CursorPagination."""
|
||||||
|
response = PaginatedResponse(
|
||||||
|
data=[1, 2],
|
||||||
|
pagination=CursorPagination(
|
||||||
|
next_cursor=None, items_per_page=10, has_more=False
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert isinstance(response.pagination, CursorPagination)
|
||||||
|
|
||||||
|
def test_pagination_alias_accepted(self):
|
||||||
|
"""Constructing PaginatedResponse with Pagination (alias) still works."""
|
||||||
|
response = PaginatedResponse(
|
||||||
|
data=[],
|
||||||
|
pagination=Pagination(
|
||||||
|
total_count=0, items_per_page=10, page=1, has_more=False
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert isinstance(response.pagination, OffsetPagination)
|
||||||
|
|
||||||
|
|
||||||
class TestFromAttributes:
|
class TestFromAttributes:
|
||||||
"""Tests for from_attributes config (ORM mode)."""
|
"""Tests for from_attributes config (ORM mode)."""
|
||||||
|
|||||||
97
zensical.toml
Normal file
97
zensical.toml
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
[project]
|
||||||
|
site_name = "FastAPI Toolsets"
|
||||||
|
site_description = "Production-ready utilities for FastAPI applications."
|
||||||
|
site_author = "d3vyce"
|
||||||
|
site_url = "https://fastapi-toolsets.d3vyce.fr"
|
||||||
|
copyright = "Copyright © 2026 d3vyce"
|
||||||
|
repo_url = "https://github.com/d3vyce/fastapi-toolsets"
|
||||||
|
|
||||||
|
[project.theme]
|
||||||
|
custom_dir = "docs/overrides"
|
||||||
|
language = "en"
|
||||||
|
features = [
|
||||||
|
"announce.dismiss",
|
||||||
|
"content.action.view",
|
||||||
|
"content.code.annotate",
|
||||||
|
"content.code.copy",
|
||||||
|
"content.code.select",
|
||||||
|
"content.footnote.tooltips",
|
||||||
|
"content.tabs.link",
|
||||||
|
"content.tooltips",
|
||||||
|
"navigation.footer",
|
||||||
|
"navigation.indexes",
|
||||||
|
"navigation.instant",
|
||||||
|
"navigation.instant.prefetch",
|
||||||
|
"navigation.path",
|
||||||
|
"navigation.sections",
|
||||||
|
"navigation.tabs",
|
||||||
|
"navigation.top",
|
||||||
|
"navigation.tracking",
|
||||||
|
"search.highlight",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[project.theme.palette]]
|
||||||
|
scheme = "default"
|
||||||
|
toggle.icon = "lucide/sun"
|
||||||
|
toggle.name = "Switch to dark mode"
|
||||||
|
|
||||||
|
[[project.theme.palette]]
|
||||||
|
scheme = "slate"
|
||||||
|
toggle.icon = "lucide/moon"
|
||||||
|
toggle.name = "Switch to light mode"
|
||||||
|
|
||||||
|
[project.theme.font]
|
||||||
|
text = "Inter"
|
||||||
|
code = "Jetbrains Mono"
|
||||||
|
|
||||||
|
[project.theme.icon]
|
||||||
|
repo = "fontawesome/brands/github"
|
||||||
|
|
||||||
|
[project.plugins.mkdocstrings.handlers.python]
|
||||||
|
inventories = ["https://docs.python.org/3/objects.inv"]
|
||||||
|
paths = ["src"]
|
||||||
|
|
||||||
|
[project.plugins.mkdocstrings.handlers.python.options]
|
||||||
|
docstring_style = "google"
|
||||||
|
inherited_members = true
|
||||||
|
show_source = false
|
||||||
|
show_root_heading = true
|
||||||
|
|
||||||
|
[project.markdown_extensions]
|
||||||
|
abbr = {}
|
||||||
|
admonition = {}
|
||||||
|
attr_list = {}
|
||||||
|
def_list = {}
|
||||||
|
footnotes = {}
|
||||||
|
md_in_html = {}
|
||||||
|
"pymdownx.arithmatex" = {generic = true}
|
||||||
|
"pymdownx.betterem" = {}
|
||||||
|
"pymdownx.caret" = {}
|
||||||
|
"pymdownx.details" = {}
|
||||||
|
"pymdownx.emoji" = {}
|
||||||
|
"pymdownx.inlinehilite" = {}
|
||||||
|
"pymdownx.keys" = {}
|
||||||
|
"pymdownx.magiclink" = {}
|
||||||
|
"pymdownx.mark" = {}
|
||||||
|
"pymdownx.smartsymbols" = {}
|
||||||
|
"pymdownx.tasklist" = {custom_checkbox = true}
|
||||||
|
"pymdownx.tilde" = {}
|
||||||
|
|
||||||
|
[project.markdown_extensions."pymdownx.highlight"]
|
||||||
|
anchor_linenums = true
|
||||||
|
line_spans = "__span"
|
||||||
|
pygments_lang_class = true
|
||||||
|
|
||||||
|
[project.markdown_extensions."pymdownx.superfences"]
|
||||||
|
custom_fences = [{name = "mermaid", class = "mermaid"}]
|
||||||
|
|
||||||
|
[project.markdown_extensions."pymdownx.tabbed"]
|
||||||
|
alternate_style = true
|
||||||
|
combine_header_slug = true
|
||||||
|
|
||||||
|
[project.markdown_extensions."toc"]
|
||||||
|
permalink = true
|
||||||
|
|
||||||
|
[project.markdown_extensions."pymdownx.snippets"]
|
||||||
|
base_path = ["."]
|
||||||
|
check_paths = true
|
||||||
Reference in New Issue
Block a user