mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
Compare commits
70 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
8c8911fb27
|
|||
|
|
c0c3b38054 | ||
|
e17d385910
|
|||
|
|
6cf7df55ef | ||
|
|
7482bc5dad | ||
|
|
9d07dfea85 | ||
|
|
31678935aa | ||
|
|
823a0b3e36 | ||
|
1591cd3d64
|
|||
|
|
6714ceeb92 | ||
|
|
73fae04333 | ||
|
|
32ed36e102 | ||
|
|
48567310bc | ||
|
|
de51ed4675 | ||
|
|
794767edbb | ||
|
|
9c136f05bb | ||
|
3299a439fe
|
|||
|
|
d5b22a72fd | ||
|
|
c32f2e18be | ||
|
d971261f98
|
|||
|
|
74a54b7396 | ||
|
|
19805ab376 | ||
|
|
d4498e2063 | ||
| f59c1a17e2 | |||
|
|
8982ba18e3 | ||
|
|
71fe6f478f | ||
|
|
1cfbf14986 | ||
|
|
e3ff535b7e | ||
|
|
8825c772ce | ||
|
|
c8c263ca8f | ||
|
2020fa2f92
|
|||
|
|
1ea316bef4 | ||
|
|
ced1a655f2 | ||
|
|
290b2a06ec | ||
|
|
baa9711665 | ||
|
d526969d0e
|
|||
|
|
e24153053e | ||
|
348ed4c148
|
|||
|
bd6e90de1b
|
|||
|
|
4404fb3df9 | ||
|
|
f68793fbdb | ||
|
|
3a69c3c788 | ||
|
e861a0a49a
|
|||
|
|
cb2cf572e0 | ||
|
494869a172
|
|||
|
|
e0bc93096d | ||
|
1ff94eb9d3
|
|||
|
|
97ab10edcd | ||
|
|
3ff7ff18bb | ||
|
0f50c8a0f0
|
|||
|
|
691fb78fda | ||
|
|
34ef4da317 | ||
|
|
8c287b3ce7 | ||
|
54f5479c24
|
|||
|
|
f467754df1 | ||
|
b57ce40b05
|
|||
|
5264631550
|
|||
|
a76f7c439d
|
|||
|
|
d14551781c | ||
|
|
577e087321 | ||
|
aa72dc2eb5
|
|||
|
|
1a98e36909 | ||
|
ba5180a73b
|
|||
|
a9f486d905
|
|||
|
53e80cd0d5
|
|||
|
|
45001767aa | ||
|
|
cd551b6bff | ||
|
|
718a12be28 | ||
|
|
fa16bf1bff | ||
|
|
c4a227f9fc |
14
.github/dependabot.yml
vendored
Normal file
14
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
version: 2
|
||||||
|
updates:
|
||||||
|
- package-ecosystem: "github-actions"
|
||||||
|
directory: "/"
|
||||||
|
schedule:
|
||||||
|
interval: "weekly"
|
||||||
|
commit-message:
|
||||||
|
prefix: ⬆
|
||||||
|
- package-ecosystem: "uv"
|
||||||
|
directory: "/"
|
||||||
|
schedule:
|
||||||
|
interval: "weekly"
|
||||||
|
commit-message:
|
||||||
|
prefix: ⬆
|
||||||
8
.github/workflows/build-release.yml
vendored
8
.github/workflows/build-release.yml
vendored
@@ -11,16 +11,16 @@ jobs:
|
|||||||
permissions:
|
permissions:
|
||||||
id-token: write
|
id-token: write
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Install uv
|
- name: Install uv
|
||||||
uses: astral-sh/setup-uv@v5
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
run: uv python install 3.13
|
run: uv python install 3.14
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync
|
run: uv sync --group dev
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
run: uv build
|
run: uv build
|
||||||
|
|||||||
32
.github/workflows/ci.yml
vendored
32
.github/workflows/ci.yml
vendored
@@ -15,16 +15,16 @@ jobs:
|
|||||||
name: Lint (Ruff)
|
name: Lint (Ruff)
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Install uv
|
- name: Install uv
|
||||||
uses: astral-sh/setup-uv@v5
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
run: uv python install 3.13
|
run: uv python install 3.13
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync --extra dev
|
run: uv sync --group dev
|
||||||
|
|
||||||
- name: Run Ruff linter
|
- name: Run Ruff linter
|
||||||
run: uv run ruff check .
|
run: uv run ruff check .
|
||||||
@@ -36,16 +36,16 @@ jobs:
|
|||||||
name: Type Check (ty)
|
name: Type Check (ty)
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Install uv
|
- name: Install uv
|
||||||
uses: astral-sh/setup-uv@v5
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
run: uv python install 3.13
|
run: uv python install 3.13
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync --extra dev
|
run: uv sync --group dev
|
||||||
|
|
||||||
- name: Run ty
|
- name: Run ty
|
||||||
run: uv run ty check
|
run: uv run ty check
|
||||||
@@ -56,7 +56,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ["3.11", "3.12", "3.13"]
|
python-version: ["3.11", "3.12", "3.13", "3.14"]
|
||||||
|
|
||||||
services:
|
services:
|
||||||
postgres:
|
postgres:
|
||||||
@@ -74,27 +74,35 @@ jobs:
|
|||||||
--health-retries 5
|
--health-retries 5
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Install uv
|
- name: Install uv
|
||||||
uses: astral-sh/setup-uv@v5
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
run: uv python install ${{ matrix.python-version }}
|
run: uv python install ${{ matrix.python-version }}
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync --extra dev
|
run: uv sync --group dev
|
||||||
|
|
||||||
- name: Run tests with coverage
|
- name: Run tests with coverage
|
||||||
env:
|
env:
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/test_db
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/test_db
|
||||||
run: |
|
run: |
|
||||||
uv run pytest --cov --cov-report=xml --cov-report=term-missing
|
uv run pytest --cov --cov-report=xml --cov-report=term-missing --junitxml=junit.xml -o junit_family=legacy
|
||||||
|
|
||||||
- name: Upload coverage to Codecov
|
- name: Upload coverage to Codecov
|
||||||
if: matrix.python-version == '3.13'
|
if: matrix.python-version == '3.14'
|
||||||
uses: codecov/codecov-action@v5
|
uses: codecov/codecov-action@v5
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
|
report_type: coverage
|
||||||
files: ./coverage.xml
|
files: ./coverage.xml
|
||||||
fail_ci_if_error: false
|
fail_ci_if_error: false
|
||||||
|
|
||||||
|
- name: Upload test results to Codecov
|
||||||
|
if: matrix.python-version == '3.14'
|
||||||
|
uses: codecov/codecov-action@v5
|
||||||
|
with:
|
||||||
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
|
report_type: test_results
|
||||||
|
|||||||
38
.github/workflows/docs.yml
vendored
Normal file
38
.github/workflows/docs.yml
vendored
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
name: Documentation
|
||||||
|
|
||||||
|
on:
|
||||||
|
release:
|
||||||
|
types: [published]
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
pages: write
|
||||||
|
id-token: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
deploy:
|
||||||
|
environment:
|
||||||
|
name: github-pages
|
||||||
|
url: ${{ steps.deployment.outputs.page_url }}
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/configure-pages@v5
|
||||||
|
|
||||||
|
- uses: actions/checkout@v6
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
run: uv python install 3.13
|
||||||
|
|
||||||
|
- run: uv sync --group dev
|
||||||
|
|
||||||
|
- run: uv run zensical build --clean
|
||||||
|
|
||||||
|
- uses: actions/upload-pages-artifact@v4
|
||||||
|
with:
|
||||||
|
path: site
|
||||||
|
|
||||||
|
- uses: actions/deploy-pages@v4
|
||||||
|
id: deployment
|
||||||
@@ -1 +1 @@
|
|||||||
3.13
|
3.14
|
||||||
|
|||||||
36
README.md
36
README.md
@@ -1,6 +1,6 @@
|
|||||||
# FastAPI Toolsets
|
# FastAPI Toolsets
|
||||||
|
|
||||||
FastAPI Toolsets provides production-ready utilities for FastAPI applications built with async SQLAlchemy and PostgreSQL. It includes generic CRUD operations, a fixture system with dependency resolution, a Django-like CLI, standardized API responses, and structured exception handling with automatic OpenAPI documentation.
|
A modular collection of production-ready utilities for FastAPI. Install only what you need — from async CRUD and database helpers to CLI tooling, Prometheus metrics, and pytest fixtures. Each module is independently installable via optional extras, keeping your dependency footprint minimal.
|
||||||
|
|
||||||
[](https://github.com/d3vyce/fastapi-toolsets/actions/workflows/ci.yml)
|
[](https://github.com/d3vyce/fastapi-toolsets/actions/workflows/ci.yml)
|
||||||
[](https://codecov.io/gh/d3vyce/fastapi-toolsets)
|
[](https://codecov.io/gh/d3vyce/fastapi-toolsets)
|
||||||
@@ -20,17 +20,43 @@ FastAPI Toolsets provides production-ready utilities for FastAPI applications bu
|
|||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
|
The base package includes the core modules (CRUD, database, schemas, exceptions, fixtures, dependencies, logging):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv add fastapi-toolsets
|
uv add fastapi-toolsets
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Install only the extras you need:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add "fastapi-toolsets[cli]" # CLI (typer)
|
||||||
|
uv add "fastapi-toolsets[metrics]" # Prometheus metrics (prometheus_client)
|
||||||
|
uv add "fastapi-toolsets[pytest]" # Pytest helpers (httpx, pytest-xdist)
|
||||||
|
```
|
||||||
|
|
||||||
|
Or install everything:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add "fastapi-toolsets[all]"
|
||||||
|
```
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- **CRUD**: Generic async CRUD operations with `CrudFactory`
|
### Core
|
||||||
- **Fixtures**: Fixture system with dependency management, context support and pytest integration
|
|
||||||
- **CLI**: Django-like command-line interface for fixtures and custom commands
|
- **CRUD**: Generic async CRUD operations with `CrudFactory`, built-in search with relationship traversal
|
||||||
- **Standardized API Responses**: Consistent response format across your API
|
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
|
||||||
|
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
|
||||||
|
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
|
||||||
|
- **Standardized API Responses**: Consistent response format with `Response`, `PaginatedResponse`, and `PydanticBase`
|
||||||
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
||||||
|
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`
|
||||||
|
|
||||||
|
### Optional
|
||||||
|
|
||||||
|
- **CLI**: Django-like command-line interface with fixture management and custom commands support
|
||||||
|
- **Metrics**: Prometheus metrics endpoint with provider/collector registry
|
||||||
|
- **Pytest Helpers**: Async test client, database session management, `pytest-xdist` support, and table cleanup utilities
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
|
|||||||
67
docs/index.md
Normal file
67
docs/index.md
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
# FastAPI Toolsets
|
||||||
|
|
||||||
|
A modular collection of production-ready utilities for FastAPI. Install only what you need — from async CRUD and database helpers to CLI tooling, Prometheus metrics, and pytest fixtures. Each module is independently installable via optional extras, keeping your dependency footprint minimal.
|
||||||
|
|
||||||
|
[](https://github.com/d3vyce/fastapi-toolsets/actions/workflows/ci.yml)
|
||||||
|
[](https://codecov.io/gh/d3vyce/fastapi-toolsets)
|
||||||
|
[](https://github.com/astral-sh/ty)
|
||||||
|
[](https://github.com/astral-sh/uv)
|
||||||
|
[](https://github.com/astral-sh/ruff)
|
||||||
|
[](https://www.python.org/downloads/)
|
||||||
|
[](https://opensource.org/licenses/MIT)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Documentation**: [https://fastapi-toolsets.d3vyce.fr](https://fastapi-toolsets.d3vyce.fr)
|
||||||
|
|
||||||
|
**Source Code**: [https://github.com/d3vyce/fastapi-toolsets](https://github.com/d3vyce/fastapi-toolsets)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
The base package includes the core modules (CRUD, database, schemas, exceptions, fixtures, dependencies, logging):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add fastapi-toolsets
|
||||||
|
```
|
||||||
|
|
||||||
|
Install only the extras you need:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add "fastapi-toolsets[cli]" # CLI (typer)
|
||||||
|
uv add "fastapi-toolsets[metrics]" # Prometheus metrics (prometheus_client)
|
||||||
|
uv add "fastapi-toolsets[pytest]" # Pytest helpers (httpx, pytest-xdist)
|
||||||
|
```
|
||||||
|
|
||||||
|
Or install everything:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add "fastapi-toolsets[all]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
### Core
|
||||||
|
|
||||||
|
- **CRUD**: Generic async CRUD operations with `CrudFactory`, built-in search with relationship traversal
|
||||||
|
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
|
||||||
|
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
|
||||||
|
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
|
||||||
|
- **Standardized API Responses**: Consistent response format with `Response`, `PaginatedResponse`, and `PydanticBase`
|
||||||
|
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
||||||
|
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`
|
||||||
|
|
||||||
|
### Optional
|
||||||
|
|
||||||
|
- **CLI**: Django-like command-line interface with fixture management and custom commands support
|
||||||
|
- **Metrics**: Prometheus metrics endpoint with provider/collector registry
|
||||||
|
- **Pytest Helpers**: Async test client, database session management, `pytest-xdist` support, and table cleanup utilities
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT License - see [LICENSE](LICENSE) for details.
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
Contributions are welcome! Please feel free to submit issues and pull requests.
|
||||||
93
docs/module/cli.md
Normal file
93
docs/module/cli.md
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
# CLI
|
||||||
|
|
||||||
|
Typer-based command-line interface for managing your FastAPI application, with built-in fixture commands integration.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
=== "uv"
|
||||||
|
``` bash
|
||||||
|
uv add "fastapi-toolsets[cli]"
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "pip"
|
||||||
|
``` bash
|
||||||
|
pip install "fastapi-toolsets[cli]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `cli` module provides a `manager` entry point built with [Typer](https://typer.tiangolo.com/). It allow custom commands to be added in addition of the fixture commands when a [`FixtureRegistry`](../reference/fixtures.md#fastapi_toolsets.fixtures.registry.FixtureRegistry) and a database context are configured.
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Configure the CLI in your `pyproject.toml`:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[tool.fastapi-toolsets]
|
||||||
|
cli = "myapp.cli:cli" # Custom Typer app
|
||||||
|
fixtures = "myapp.fixtures:registry" # FixtureRegistry instance
|
||||||
|
db_context = "myapp.db:db_context" # Async context manager for sessions
|
||||||
|
```
|
||||||
|
|
||||||
|
All fields are optional. Without configuration, the `manager` command still works but no command are available.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Manager commands
|
||||||
|
manager --help
|
||||||
|
|
||||||
|
Usage: manager [OPTIONS] COMMAND [ARGS]...
|
||||||
|
|
||||||
|
FastAPI utilities CLI.
|
||||||
|
|
||||||
|
╭─ Options ────────────────────────────────────────────────────────────────────────╮
|
||||||
|
│ --install-completion Install completion for the current shell. │
|
||||||
|
│ --show-completion Show completion for the current shell, to copy it │
|
||||||
|
│ or customize the installation. │
|
||||||
|
│ --help Show this message and exit. │
|
||||||
|
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||||
|
╭─ Commands ───────────────────────────────────────────────────────────────────────╮
|
||||||
|
│ check-db │
|
||||||
|
│ fixtures Manage database fixtures. │
|
||||||
|
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||||
|
|
||||||
|
# Fixtures commands
|
||||||
|
manager fixtures --help
|
||||||
|
|
||||||
|
Usage: manager fixtures [OPTIONS] COMMAND [ARGS]...
|
||||||
|
|
||||||
|
Manage database fixtures.
|
||||||
|
|
||||||
|
╭─ Options ────────────────────────────────────────────────────────────────────────╮
|
||||||
|
│ --help Show this message and exit. │
|
||||||
|
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||||
|
╭─ Commands ───────────────────────────────────────────────────────────────────────╮
|
||||||
|
│ list List all registered fixtures. │
|
||||||
|
│ load Load fixtures into the database. │
|
||||||
|
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||||
|
```
|
||||||
|
|
||||||
|
## Custom CLI
|
||||||
|
|
||||||
|
You can extend the CLI by providing your own Typer app. The `manager` entry point will merge your app's commands with the built-in ones:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# myapp/cli.py
|
||||||
|
import typer
|
||||||
|
|
||||||
|
cli = typer.Typer()
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
|
def hello():
|
||||||
|
print("Hello from my app!")
|
||||||
|
```
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[tool.fastapi-toolsets]
|
||||||
|
cli = "myapp.cli:cli"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/cli.md)
|
||||||
313
docs/module/crud.md
Normal file
313
docs/module/crud.md
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
# CRUD
|
||||||
|
|
||||||
|
Generic async CRUD operations for SQLAlchemy models with search, pagination, and many-to-many support.
|
||||||
|
|
||||||
|
!!! info
|
||||||
|
This module has been coded and tested to be compatible with PostgreSQL only.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `crud` module provides [`AsyncCrud`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud), an abstract base class with a full suite of async database operations, and [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory), a convenience function to instantiate it for a given model.
|
||||||
|
|
||||||
|
## Creating a CRUD class
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.crud import CrudFactory
|
||||||
|
from myapp.models import User
|
||||||
|
|
||||||
|
UserCrud = CrudFactory(model=User)
|
||||||
|
```
|
||||||
|
|
||||||
|
[`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory) dynamically creates a class named `AsyncUserCrud` with `User` as its model.
|
||||||
|
|
||||||
|
## Basic operations
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Create
|
||||||
|
user = await UserCrud.create(session=session, obj=UserCreateSchema(username="alice"))
|
||||||
|
|
||||||
|
# Get one (raises NotFoundError if not found)
|
||||||
|
user = await UserCrud.get(session=session, filters=[User.id == user_id])
|
||||||
|
|
||||||
|
# Get first or None
|
||||||
|
user = await UserCrud.first(session=session, filters=[User.email == email])
|
||||||
|
|
||||||
|
# Get multiple
|
||||||
|
users = await UserCrud.get_multi(session=session, filters=[User.is_active == True])
|
||||||
|
|
||||||
|
# Update
|
||||||
|
user = await UserCrud.update(session=session, obj=UserUpdateSchema(username="bob"), filters=[User.id == user_id])
|
||||||
|
|
||||||
|
# Delete
|
||||||
|
await UserCrud.delete(session=session, filters=[User.id == user_id])
|
||||||
|
|
||||||
|
# Count / exists
|
||||||
|
count = await UserCrud.count(session=session, filters=[User.is_active == True])
|
||||||
|
exists = await UserCrud.exists(session=session, filters=[User.email == email])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pagination
|
||||||
|
|
||||||
|
!!! info "Added in `v1.1` (only offset_pagination via `paginate` if `<v1.1`)"
|
||||||
|
|
||||||
|
Two pagination strategies are available. Both return a [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse) but differ in how they navigate through results.
|
||||||
|
|
||||||
|
| | `offset_paginate` | `cursor_paginate` |
|
||||||
|
|---|---|---|
|
||||||
|
| Total count | Yes | No |
|
||||||
|
| Jump to arbitrary page | Yes | No |
|
||||||
|
| Performance on deep pages | Degrades | Constant |
|
||||||
|
| Stable under concurrent inserts | No | Yes |
|
||||||
|
| Search compatible | Yes | Yes |
|
||||||
|
| Use case | Admin panels, numbered pagination | Feeds, APIs, infinite scroll |
|
||||||
|
|
||||||
|
### Offset pagination
|
||||||
|
|
||||||
|
```python
|
||||||
|
@router.get(
|
||||||
|
"",
|
||||||
|
response_model=PaginatedResponse[User],
|
||||||
|
)
|
||||||
|
async def get_users(
|
||||||
|
session: SessionDep,
|
||||||
|
items_per_page: int = 50,
|
||||||
|
page: int = 1,
|
||||||
|
):
|
||||||
|
return await crud.UserCrud.offset_paginate(
|
||||||
|
session=session,
|
||||||
|
items_per_page=items_per_page,
|
||||||
|
page=page,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
The [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) method returns a [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse) whose `pagination` field is an [`OffsetPagination`](../reference/schemas.md#fastapi_toolsets.schemas.OffsetPagination) object:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"data": ["..."],
|
||||||
|
"pagination": {
|
||||||
|
"total_count": 100,
|
||||||
|
"page": 1,
|
||||||
|
"items_per_page": 20,
|
||||||
|
"has_more": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! warning "Deprecated: `paginate`"
|
||||||
|
The `paginate` function is a backward-compatible alias for `offset_paginate`. This function is **deprecated** and will be removed in **v2.0**.
|
||||||
|
|
||||||
|
### Cursor pagination
|
||||||
|
|
||||||
|
```python
|
||||||
|
@router.get(
|
||||||
|
"",
|
||||||
|
response_model=PaginatedResponse[UserRead],
|
||||||
|
)
|
||||||
|
async def list_users(
|
||||||
|
session: SessionDep,
|
||||||
|
cursor: str | None = None,
|
||||||
|
items_per_page: int = 20,
|
||||||
|
):
|
||||||
|
return await UserCrud.cursor_paginate(
|
||||||
|
session=session,
|
||||||
|
cursor=cursor,
|
||||||
|
items_per_page=items_per_page,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
The [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) method returns a [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse) whose `pagination` field is a [`CursorPagination`](../reference/schemas.md#fastapi_toolsets.schemas.CursorPagination) object:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"data": ["..."],
|
||||||
|
"pagination": {
|
||||||
|
"next_cursor": "eyJ2YWx1ZSI6ICIzZjQ3YWM2OS0uLi4ifQ==",
|
||||||
|
"prev_cursor": null,
|
||||||
|
"items_per_page": 20,
|
||||||
|
"has_more": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Pass `next_cursor` as the `cursor` query parameter on the next request to advance to the next page. `prev_cursor` is set on pages 2+ and points back to the first item of the current page. Both are `null` when there is no adjacent page.
|
||||||
|
|
||||||
|
#### Choosing a cursor column
|
||||||
|
|
||||||
|
The cursor column is set once on [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory) via the `cursor_column` parameter. It must be monotonically ordered for stable results:
|
||||||
|
|
||||||
|
- Auto-increment integer PKs
|
||||||
|
- UUID v7 PKs
|
||||||
|
- Timestamps
|
||||||
|
|
||||||
|
!!! warning
|
||||||
|
Random UUID v4 PKs are **not** suitable as cursor columns because their ordering is non-deterministic.
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
`cursor_column` is required. Calling [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) on a CRUD class that has no `cursor_column` configured raises a `ValueError`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Paginate by the primary key
|
||||||
|
PostCrud = CrudFactory(model=Post, cursor_column=Post.id)
|
||||||
|
|
||||||
|
# Paginate by a timestamp column instead
|
||||||
|
PostCrud = CrudFactory(model=Post, cursor_column=Post.created_at)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Search
|
||||||
|
|
||||||
|
Declare searchable fields on the CRUD class. Relationship traversal is supported via tuples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
PostCrud = CrudFactory(
|
||||||
|
model=Post,
|
||||||
|
searchable_fields=[
|
||||||
|
Post.title,
|
||||||
|
Post.content,
|
||||||
|
(Post.author, User.username), # search across relationship
|
||||||
|
],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
This allows searching with both [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) and [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate):
|
||||||
|
|
||||||
|
```python
|
||||||
|
@router.get(
|
||||||
|
"",
|
||||||
|
response_model=PaginatedResponse[User],
|
||||||
|
)
|
||||||
|
async def get_users(
|
||||||
|
session: SessionDep,
|
||||||
|
items_per_page: int = 50,
|
||||||
|
page: int = 1,
|
||||||
|
search: str | None = None,
|
||||||
|
):
|
||||||
|
return await crud.UserCrud.offset_paginate(
|
||||||
|
session=session,
|
||||||
|
items_per_page=items_per_page,
|
||||||
|
page=page,
|
||||||
|
search=search,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
@router.get(
|
||||||
|
"",
|
||||||
|
response_model=PaginatedResponse[User],
|
||||||
|
)
|
||||||
|
async def get_users(
|
||||||
|
session: SessionDep,
|
||||||
|
cursor: str | None = None,
|
||||||
|
items_per_page: int = 50,
|
||||||
|
search: str | None = None,
|
||||||
|
):
|
||||||
|
return await crud.UserCrud.cursor_paginate(
|
||||||
|
session=session,
|
||||||
|
items_per_page=items_per_page,
|
||||||
|
cursor=cursor,
|
||||||
|
search=search,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Relationship loading
|
||||||
|
|
||||||
|
!!! info "Added in `v1.1`"
|
||||||
|
|
||||||
|
By default, SQLAlchemy relationships are not loaded unless explicitly requested. Instead of using `lazy="selectin"` on model definitions (which is implicit and applies globally), define a `default_load_options` on the CRUD class to control loading strategy explicitly.
|
||||||
|
|
||||||
|
!!! warning
|
||||||
|
Avoid using `lazy="selectin"` on model relationships. It fires silently on every query, cannot be disabled per-call, and can cause unexpected cascading loads through deep relationship chains. Use `default_load_options` instead.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
ArticleCrud = CrudFactory(
|
||||||
|
model=Article,
|
||||||
|
default_load_options=[
|
||||||
|
selectinload(Article.category),
|
||||||
|
selectinload(Article.tags),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
`default_load_options` applies automatically to all read operations (`get`, `first`, `get_multi`, `offset_paginate`, `cursor_paginate`). When `load_options` is passed at call-site, it **fully replaces** `default_load_options` for that query — giving you precise per-call control:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Only loads category, tags are not loaded
|
||||||
|
article = await ArticleCrud.get(
|
||||||
|
session=session,
|
||||||
|
filters=[Article.id == article_id],
|
||||||
|
load_options=[selectinload(Article.category)],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Loads nothing — useful for write-then-refresh flows or lightweight checks
|
||||||
|
articles = await ArticleCrud.get_multi(session=session, load_options=[])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Many-to-many relationships
|
||||||
|
|
||||||
|
Use `m2m_fields` to map schema fields containing lists of IDs to SQLAlchemy relationships. The CRUD class resolves and validates all IDs before persisting:
|
||||||
|
|
||||||
|
```python
|
||||||
|
PostCrud = CrudFactory(
|
||||||
|
model=Post,
|
||||||
|
m2m_fields={"tag_ids": Post.tags},
|
||||||
|
)
|
||||||
|
|
||||||
|
post = await PostCrud.create(session=session, obj=PostCreateSchema(title="Hello", tag_ids=[1, 2, 3]))
|
||||||
|
```
|
||||||
|
|
||||||
|
## Upsert
|
||||||
|
|
||||||
|
Atomic `INSERT ... ON CONFLICT DO UPDATE` using PostgreSQL:
|
||||||
|
|
||||||
|
```python
|
||||||
|
await UserCrud.upsert(
|
||||||
|
session=session,
|
||||||
|
obj=UserCreateSchema(email="alice@example.com", username="alice"),
|
||||||
|
index_elements=[User.email],
|
||||||
|
set_={"username"},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## `schema` — typed response serialization
|
||||||
|
|
||||||
|
!!! info "Added in `v1.1`"
|
||||||
|
|
||||||
|
Pass a Pydantic schema class to `create`, `get`, `update`, or `offset_paginate` to serialize the result directly into that schema and wrap it in a [`Response[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.Response) or [`PaginatedResponse[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse):
|
||||||
|
|
||||||
|
```python
|
||||||
|
class UserRead(PydanticBase):
|
||||||
|
id: UUID
|
||||||
|
username: str
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/{uuid}",
|
||||||
|
responses=generate_error_responses(NotFoundError),
|
||||||
|
)
|
||||||
|
async def get_user(session: SessionDep, uuid: UUID) -> Response[UserRead]:
|
||||||
|
return await crud.UserCrud.get(
|
||||||
|
session=session,
|
||||||
|
filters=[User.id == uuid],
|
||||||
|
schema=UserRead,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def list_users(session: SessionDep, page: int = 1) -> PaginatedResponse[UserRead]:
|
||||||
|
return await crud.UserCrud.offset_paginate(
|
||||||
|
session=session,
|
||||||
|
page=page,
|
||||||
|
schema=UserRead,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
The schema must have `from_attributes=True` (or inherit from [`PydanticBase`](../reference/schemas.md#fastapi_toolsets.schemas.PydanticBase)) so it can be built from SQLAlchemy model instances.
|
||||||
|
|
||||||
|
!!! warning "Deprecated: `as_response`"
|
||||||
|
The `as_response=True` parameter is **deprecated** and will be removed in **v2.0**. Replace it with `schema=YourSchema`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/crud.md)
|
||||||
92
docs/module/db.md
Normal file
92
docs/module/db.md
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
# DB
|
||||||
|
|
||||||
|
SQLAlchemy async session management with transactions, table locking, and row-change polling.
|
||||||
|
|
||||||
|
!!! info
|
||||||
|
This module has been coded and tested to be compatible with PostgreSQL only.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `db` module provides helpers to create FastAPI dependencies and context managers for `AsyncSession`, along with utilities for nested transactions, table lock and polling for row changes.
|
||||||
|
|
||||||
|
## Session dependency
|
||||||
|
|
||||||
|
Use [`create_db_dependency`](../reference/db.md#fastapi_toolsets.db.create_db_dependency) to create a FastAPI dependency that yields a session and auto-commits on success:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||||
|
from fastapi_toolsets.db import create_db_dependency
|
||||||
|
|
||||||
|
engine = create_async_engine(url="postgresql+asyncpg://...", future=True)
|
||||||
|
session_maker = async_sessionmaker(bind=engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
get_db = create_db_dependency(session_maker=session_maker)
|
||||||
|
|
||||||
|
@router.get("/users")
|
||||||
|
async def list_users(session: AsyncSession = Depends(get_db)):
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Session context manager
|
||||||
|
|
||||||
|
Use [`create_db_context`](../reference/db.md#fastapi_toolsets.db.create_db_context) for sessions outside request handlers (e.g. background tasks, CLI commands):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import create_db_context
|
||||||
|
|
||||||
|
db_context = create_db_context(session_maker=session_maker)
|
||||||
|
|
||||||
|
async def seed():
|
||||||
|
async with db_context() as session:
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Nested transactions
|
||||||
|
|
||||||
|
[`get_transaction`](../reference/db.md#fastapi_toolsets.db.get_transaction) handles savepoints automatically, allowing safe nesting:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import get_transaction
|
||||||
|
|
||||||
|
async def create_user_with_role(session=session):
|
||||||
|
async with get_transaction(session=session):
|
||||||
|
...
|
||||||
|
async with get_transaction(session=session): # uses savepoint
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Table locking
|
||||||
|
|
||||||
|
[`lock_tables`](../reference/db.md#fastapi_toolsets.db.lock_tables) acquires PostgreSQL table-level locks before executing critical sections:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import lock_tables
|
||||||
|
|
||||||
|
async with lock_tables(session=session, tables=[User], mode="EXCLUSIVE"):
|
||||||
|
# No other transaction can modify User until this block exits
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
Available lock modes are defined in [`LockMode`](../reference/db.md#fastapi_toolsets.db.LockMode): `ACCESS_SHARE`, `ROW_SHARE`, `ROW_EXCLUSIVE`, `SHARE_UPDATE_EXCLUSIVE`, `SHARE`, `SHARE_ROW_EXCLUSIVE`, `EXCLUSIVE`, `ACCESS_EXCLUSIVE`.
|
||||||
|
|
||||||
|
## Row-change polling
|
||||||
|
|
||||||
|
[`wait_for_row_change`](../reference/db.md#fastapi_toolsets.db.wait_for_row_change) polls a row until a specific column changes value, useful for waiting on async side effects:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import wait_for_row_change
|
||||||
|
|
||||||
|
# Wait up to 30s for order.status to change
|
||||||
|
await wait_for_row_change(
|
||||||
|
session=session,
|
||||||
|
model=Order,
|
||||||
|
pk_value=order_id,
|
||||||
|
columns=[Order.status],
|
||||||
|
interval=1.0,
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/db.md)
|
||||||
50
docs/module/dependencies.md
Normal file
50
docs/module/dependencies.md
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
# Dependencies
|
||||||
|
|
||||||
|
FastAPI dependency factories for automatic model resolution from path and body parameters.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `dependencies` module provides two factory functions that create FastAPI dependencies to fetch a model instance from the database automatically — either from a path parameter or from a request body field — and inject it directly into your route handler.
|
||||||
|
|
||||||
|
## `PathDependency`
|
||||||
|
|
||||||
|
[`PathDependency`](../reference/dependencies.md#fastapi_toolsets.dependencies.PathDependency) resolves a model from a URL path parameter and injects it into the route handler. Raises [`NotFoundError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NotFoundError) automatically if the record does not exist.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.dependencies import PathDependency
|
||||||
|
|
||||||
|
UserDep = PathDependency(model=User, field=User.id, session_dep=get_db)
|
||||||
|
|
||||||
|
@router.get("/users/{user_id}")
|
||||||
|
async def get_user(user: User = UserDep):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
By default the parameter name is inferred from the field (`user_id` for `User.id`). You can override it:
|
||||||
|
|
||||||
|
```python
|
||||||
|
UserDep = PathDependency(model=User, field=User.id, session_dep=get_db, param_name="id")
|
||||||
|
|
||||||
|
@router.get("/users/{id}")
|
||||||
|
async def get_user(user: User = UserDep):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
## `BodyDependency`
|
||||||
|
|
||||||
|
[`BodyDependency`](../reference/dependencies.md#fastapi_toolsets.dependencies.BodyDependency) resolves a model from a field in the request body. Useful when a body contains a foreign key and you want the full object injected:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.dependencies import BodyDependency
|
||||||
|
|
||||||
|
RoleDep = BodyDependency(model=Role, field=Role.id, session_dep=get_db, body_field="role_id")
|
||||||
|
|
||||||
|
@router.post("/users")
|
||||||
|
async def create_user(body: UserCreateSchema, role: Role = RoleDep):
|
||||||
|
user = User(username=body.username, role=role)
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/dependencies.md)
|
||||||
85
docs/module/exceptions.md
Normal file
85
docs/module/exceptions.md
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
# Exceptions
|
||||||
|
|
||||||
|
Structured API exceptions with consistent error responses and automatic OpenAPI documentation.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `exceptions` module provides a set of pre-built HTTP exceptions and a FastAPI exception handler that formats all errors — including validation errors — into a uniform [`ErrorResponse`](../reference/schemas.md#fastapi_toolsets.schemas.ErrorResponse).
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
Register the exception handlers on your FastAPI app at startup:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app=app)
|
||||||
|
```
|
||||||
|
|
||||||
|
This registers handlers for:
|
||||||
|
|
||||||
|
- [`ApiException`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ApiException) — all custom exceptions below
|
||||||
|
- `RequestValidationError` — Pydantic request validation (422)
|
||||||
|
- `ResponseValidationError` — Pydantic response validation (422)
|
||||||
|
- `Exception` — unhandled errors (500)
|
||||||
|
|
||||||
|
## Built-in exceptions
|
||||||
|
|
||||||
|
| Exception | Status | Default message |
|
||||||
|
|-----------|--------|-----------------|
|
||||||
|
| [`UnauthorizedError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.UnauthorizedError) | 401 | Unauthorized |
|
||||||
|
| [`ForbiddenError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ForbiddenError) | 403 | Forbidden |
|
||||||
|
| [`NotFoundError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NotFoundError) | 404 | Not found |
|
||||||
|
| [`ConflictError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ConflictError) | 409 | Conflict |
|
||||||
|
| [`NoSearchableFieldsError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError) | 400 | No searchable fields |
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.exceptions import NotFoundError
|
||||||
|
|
||||||
|
@router.get("/users/{id}")
|
||||||
|
async def get_user(id: int, session: AsyncSession = Depends(get_db)):
|
||||||
|
user = await UserCrud.first(session=session, filters=[User.id == id])
|
||||||
|
if not user:
|
||||||
|
raise NotFoundError
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
## Custom exceptions
|
||||||
|
|
||||||
|
Subclass [`ApiException`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ApiException) and define an `api_error` class variable:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.exceptions import ApiException
|
||||||
|
from fastapi_toolsets.schemas import ApiError
|
||||||
|
|
||||||
|
class PaymentRequiredError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=402,
|
||||||
|
msg="Payment required",
|
||||||
|
desc="Your subscription has expired.",
|
||||||
|
err_code="PAYMENT_REQUIRED",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## OpenAPI response documentation
|
||||||
|
|
||||||
|
Use [`generate_error_responses`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.generate_error_responses) to add error schemas to your endpoint's OpenAPI spec:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.exceptions import generate_error_responses, NotFoundError, ForbiddenError
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/users/{id}",
|
||||||
|
responses=generate_error_responses(NotFoundError, ForbiddenError),
|
||||||
|
)
|
||||||
|
async def get_user(...): ...
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! info
|
||||||
|
    The Pydantic validation error response is automatically added by FastAPI.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/exceptions.md)
|
||||||
123
docs/module/fixtures.md
Normal file
123
docs/module/fixtures.md
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
# Fixtures
|
||||||
|
|
||||||
|
Dependency-aware database seeding with context-based loading strategies.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `fixtures` module lets you define named fixtures with dependencies between them, then load them into the database in the correct order. Fixtures can be scoped to contexts (e.g. base data, testing data) so that only the relevant ones are loaded for each environment.
|
||||||
|
|
||||||
|
## Defining fixtures
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import FixtureRegistry, Context
|
||||||
|
|
||||||
|
fixtures = FixtureRegistry()
|
||||||
|
|
||||||
|
@fixtures.register
|
||||||
|
def roles():
|
||||||
|
return [
|
||||||
|
Role(id=1, name="admin"),
|
||||||
|
Role(id=2, name="user"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
|
||||||
|
def test_users():
|
||||||
|
return [
|
||||||
|
User(id=1, username="alice", role_id=1),
|
||||||
|
User(id=2, username="bob", role_id=2),
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
Dependencies declared via `depends_on` are resolved topologically — `roles` will always be loaded before `test_users`.
|
||||||
|
|
||||||
|
## Loading fixtures
|
||||||
|
|
||||||
|
By context with [`load_fixtures_by_context`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.load_fixtures_by_context):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import load_fixtures_by_context
|
||||||
|
|
||||||
|
async with db_context() as session:
|
||||||
|
await load_fixtures_by_context(session=session, registry=fixtures, context=Context.TESTING)
|
||||||
|
```
|
||||||
|
|
||||||
|
Directly with [`load_fixtures`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.load_fixtures):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import load_fixtures
|
||||||
|
|
||||||
|
async with db_context() as session:
|
||||||
|
await load_fixtures(session=session, registry=fixtures)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Contexts
|
||||||
|
|
||||||
|
[`Context`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.Context) is an enum with predefined values:
|
||||||
|
|
||||||
|
| Context | Description |
|
||||||
|
|---------|-------------|
|
||||||
|
| `Context.BASE` | Core data required in all environments |
|
||||||
|
| `Context.TESTING` | Data only loaded during tests |
|
||||||
|
| `Context.PRODUCTION` | Data only loaded in production |
|
||||||
|
|
||||||
|
A fixture with no `contexts` defined takes `Context.BASE` by default.
|
||||||
|
|
||||||
|
## Load strategies
|
||||||
|
|
||||||
|
[`LoadStrategy`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.LoadStrategy) controls how the fixture loader handles rows that already exist:
|
||||||
|
|
||||||
|
| Strategy | Description |
|
||||||
|
|----------|-------------|
|
||||||
|
| `LoadStrategy.INSERT` | Insert only, fail on duplicates |
|
||||||
|
| `LoadStrategy.UPSERT` | Insert or update on conflict |
|
||||||
|
| `LoadStrategy.SKIP` | Skip rows that already exist |
|
||||||
|
|
||||||
|
## Merging registries
|
||||||
|
|
||||||
|
Split fixture definitions across modules and merge them:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from myapp.fixtures.dev import dev_fixtures
|
||||||
|
from myapp.fixtures.prod import prod_fixtures
|
||||||
|
|
||||||
|
fixtures = FixtureRegistry()
|
||||||
|
fixtures.include_registry(registry=dev_fixtures)
|
||||||
|
fixtures.include_registry(registry=prod_fixtures)
```
|
||||||
|
|
||||||
|
## Pytest integration
|
||||||
|
|
||||||
|
Use [`register_fixtures`](../reference/pytest.md#fastapi_toolsets.pytest.plugin.register_fixtures) to expose each fixture in your registry as an injectable pytest fixture named `fixture_{name}` by default:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# conftest.py
|
||||||
|
import pytest
|
||||||
|
from fastapi_toolsets.pytest import create_db_session, register_fixtures
|
||||||
|
from app.fixtures import registry
|
||||||
|
from app.models import Base
|
||||||
|
|
||||||
|
DATABASE_URL = "postgresql+asyncpg://user:pass@localhost/test_db"
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session():
|
||||||
|
async with create_db_session(database_url=DATABASE_URL, base=Base, cleanup=True) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
register_fixtures(registry=registry, namespace=globals())
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# test_users.py
|
||||||
|
async def test_user_can_login(fixture_users: list[User], fixture_roles: list[Role]):
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
The load order is resolved automatically from the `depends_on` declarations in your registry. Each generated fixture receives `db_session` as a dependency and returns the list of loaded model instances.
|
||||||
|
|
||||||
|
## CLI integration
|
||||||
|
|
||||||
|
Fixtures can be triggered from the CLI. See the [CLI module](cli.md) for setup instructions.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/fixtures.md)
|
||||||
39
docs/module/logger.md
Normal file
39
docs/module/logger.md
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# Logger
|
||||||
|
|
||||||
|
Lightweight logging utilities with consistent formatting and uvicorn integration.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `logger` module provides two helpers: one to configure the root logger (and uvicorn loggers) at startup, and one to retrieve a named logger anywhere in your codebase.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
Call [`configure_logging`](../reference/logger.md#fastapi_toolsets.logger.configure_logging) once at application startup:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import configure_logging
|
||||||
|
|
||||||
|
configure_logging(level="INFO")
|
||||||
|
```
|
||||||
|
|
||||||
|
This sets up a stdout handler with a consistent format and also configures uvicorn's access and error loggers so all log output shares the same style.
|
||||||
|
|
||||||
|
## Getting a logger
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__)
|
||||||
|
logger.info("User created")
|
||||||
|
```
|
||||||
|
|
||||||
|
When called without arguments, [`get_logger`](../reference/logger.md#fastapi_toolsets.logger.get_logger) auto-detects the caller's module name via frame inspection:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Equivalent to get_logger(name=__name__)
|
||||||
|
logger = get_logger()
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/logger.md)
|
||||||
86
docs/module/metrics.md
Normal file
86
docs/module/metrics.md
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
# Metrics
|
||||||
|
|
||||||
|
Prometheus metrics integration with a decorator-based registry and multi-process support.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
=== "uv"
|
||||||
|
``` bash
|
||||||
|
uv add "fastapi-toolsets[metrics]"
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "pip"
|
||||||
|
``` bash
|
||||||
|
pip install "fastapi-toolsets[metrics]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `metrics` module provides a [`MetricsRegistry`](../reference/metrics.md#fastapi_toolsets.metrics.registry.MetricsRegistry) to declare Prometheus metrics with decorators, and an [`init_metrics`](../reference/metrics.md#fastapi_toolsets.metrics.handler.init_metrics) function to mount a `/metrics` endpoint on your FastAPI app.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.metrics import MetricsRegistry, init_metrics
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
metrics = MetricsRegistry()
|
||||||
|
|
||||||
|
init_metrics(app=app, registry=metrics)
|
||||||
|
```
|
||||||
|
|
||||||
|
This mounts the `/metrics` endpoint that Prometheus can scrape.
|
||||||
|
|
||||||
|
## Declaring metrics
|
||||||
|
|
||||||
|
### Providers
|
||||||
|
|
||||||
|
Providers are called once at startup and register metrics that are updated externally (e.g. counters, histograms):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from prometheus_client import Counter, Histogram
|
||||||
|
|
||||||
|
@metrics.register
|
||||||
|
def http_requests():
|
||||||
|
return Counter("http_requests_total", "Total HTTP requests", ["method", "status"])
|
||||||
|
|
||||||
|
@metrics.register
|
||||||
|
def request_duration():
|
||||||
|
return Histogram("request_duration_seconds", "Request duration")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Collectors
|
||||||
|
|
||||||
|
Collectors are called on every scrape. Use them for metrics that reflect current state (e.g. gauges):
|
||||||
|
|
||||||
|
```python
|
||||||
|
@metrics.register(collect=True)
|
||||||
|
def queue_depth():
|
||||||
|
gauge = Gauge("queue_depth", "Current queue depth")
|
||||||
|
gauge.set(get_current_queue_depth())
|
||||||
|
```
|
||||||
|
|
||||||
|
## Merging registries
|
||||||
|
|
||||||
|
Split metrics definitions across modules and merge them:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from myapp.metrics.http import http_metrics
|
||||||
|
from myapp.metrics.db import db_metrics
|
||||||
|
|
||||||
|
metrics = MetricsRegistry()
|
||||||
|
metrics.include_registry(registry=http_metrics)
|
||||||
|
metrics.include_registry(registry=db_metrics)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Multi-process mode
|
||||||
|
|
||||||
|
Multi-process support is enabled automatically when the `PROMETHEUS_MULTIPROC_DIR` environment variable is set. No code changes are required.
|
||||||
|
|
||||||
|
!!! warning "Environment variable name"
|
||||||
|
The correct variable is `PROMETHEUS_MULTIPROC_DIR` (not `PROMETHEUS_MULTIPROCESS_DIR`).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/metrics.md)
|
||||||
86
docs/module/pytest.md
Normal file
86
docs/module/pytest.md
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
# Pytest
|
||||||
|
|
||||||
|
Testing helpers for FastAPI applications with async client, database sessions, and parallel worker support.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
=== "uv"
|
||||||
|
``` bash
|
||||||
|
uv add "fastapi-toolsets[pytest]"
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "pip"
|
||||||
|
``` bash
|
||||||
|
pip install "fastapi-toolsets[pytest]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `pytest` module provides utilities for setting up async test clients, managing test database sessions, and supporting parallel test execution with `pytest-xdist`.
|
||||||
|
|
||||||
|
## Creating an async client
|
||||||
|
|
||||||
|
Use [`create_async_client`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_async_client) to get an `httpx.AsyncClient` configured for your FastAPI app:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import create_async_client
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def http_client(db_session):
|
||||||
|
async def _override_get_db():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
async with create_async_client(
|
||||||
|
app=app,
|
||||||
|
base_url="http://127.0.0.1/api/v1",
|
||||||
|
dependency_overrides={get_db: _override_get_db},
|
||||||
|
) as c:
|
||||||
|
yield c
|
||||||
|
```
|
||||||
|
|
||||||
|
## Database sessions in tests
|
||||||
|
|
||||||
|
Use [`create_db_session`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_db_session) to create an isolated `AsyncSession` for a test:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import create_db_session, create_worker_database
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
async def worker_db_url():
|
||||||
|
async with create_worker_database(
|
||||||
|
database_url=str(settings.SQLALCHEMY_DATABASE_URI)
|
||||||
|
) as url:
|
||||||
|
yield url
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session(worker_db_url):
|
||||||
|
async with create_db_session(
|
||||||
|
database_url=worker_db_url, base=Base, cleanup=True
|
||||||
|
) as session:
|
||||||
|
yield session
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! info
|
||||||
|
In this example, the database is reset between each test using the argument `cleanup=True`.
|
||||||
|
|
||||||
|
## Parallel testing with pytest-xdist
|
||||||
|
|
||||||
|
The examples above are already compatible with parallel test execution using `pytest-xdist`.
|
||||||
|
|
||||||
|
## Cleaning up tables
|
||||||
|
|
||||||
|
If you want to manually clean up a database, you can use [`cleanup_tables`](../reference/pytest.md#fastapi_toolsets.pytest.utils.cleanup_tables); this truncates all tables between tests for fast isolation:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import cleanup_tables
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
async def clean(db_session):
|
||||||
|
yield
|
||||||
|
await cleanup_tables(session=db_session, base=Base)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/pytest.md)
|
||||||
49
docs/module/schemas.md
Normal file
49
docs/module/schemas.md
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
# Schemas
|
||||||
|
|
||||||
|
Standardized Pydantic response models for consistent API responses across your FastAPI application.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `schemas` module provides generic response wrappers that enforce a uniform response structure. All models use `from_attributes=True` for ORM compatibility and `validate_assignment=True` for runtime type safety.
|
||||||
|
|
||||||
|
## Response models
|
||||||
|
|
||||||
|
### [`Response[T]`](../reference/schemas.md#fastapi_toolsets.schemas.Response)
|
||||||
|
|
||||||
|
The most common wrapper for a single resource response.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.schemas import Response
|
||||||
|
|
||||||
|
@router.get("/users/{id}")
|
||||||
|
async def get_user(user: User = UserDep) -> Response[UserSchema]:
|
||||||
|
return Response(data=user, message="User retrieved")
|
||||||
|
```
|
||||||
|
|
||||||
|
### [`PaginatedResponse[T]`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse)
|
||||||
|
|
||||||
|
Wraps a list of items with pagination metadata.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.schemas import PaginatedResponse, Pagination
|
||||||
|
|
||||||
|
@router.get("/users")
|
||||||
|
async def list_users() -> PaginatedResponse[UserSchema]:
|
||||||
|
return PaginatedResponse(
|
||||||
|
data=users,
|
||||||
|
pagination=Pagination(
|
||||||
|
total_count=100,
|
||||||
|
items_per_page=10,
|
||||||
|
page=1,
|
||||||
|
has_more=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### [`ErrorResponse`](../reference/schemas.md#fastapi_toolsets.schemas.ErrorResponse)
|
||||||
|
|
||||||
|
Returned automatically by the exceptions handler.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/schemas.md)
|
||||||
7
docs/overrides/main.html
Normal file
7
docs/overrides/main.html
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
{% extends "base.html" %} {% block extrahead %}
|
||||||
|
<script
|
||||||
|
defer
|
||||||
|
src="https://analytics.d3vyce.fr/script.js"
|
||||||
|
data-website-id="338b8816-7b99-4c6a-82f3-15595be3fd47"
|
||||||
|
></script>
|
||||||
|
{{ super() }} {% endblock %}
|
||||||
27
docs/reference/cli.md
Normal file
27
docs/reference/cli.md
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
# `cli`
|
||||||
|
|
||||||
|
Here's the reference for the CLI configuration helpers used to load settings from `pyproject.toml`.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.cli.config`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.cli.config import (
|
||||||
|
import_from_string,
|
||||||
|
get_config_value,
|
||||||
|
get_fixtures_registry,
|
||||||
|
get_db_context,
|
||||||
|
get_custom_cli,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.import_from_string
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.get_config_value
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.get_fixtures_registry
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.get_db_context
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.get_custom_cli
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.utils.async_command
|
||||||
20
docs/reference/crud.md
Normal file
20
docs/reference/crud.md
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
# `crud`
|
||||||
|
|
||||||
|
Here's the reference for the CRUD classes, factory, and search utilities.
|
||||||
|
|
||||||
|
You can import the main symbols from `fastapi_toolsets.crud`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.crud import CrudFactory, AsyncCrud
|
||||||
|
from fastapi_toolsets.crud.search import SearchConfig, get_searchable_fields, build_search_filters
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.factory.AsyncCrud
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.factory.CrudFactory
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.search.SearchConfig
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.search.get_searchable_fields
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.search.build_search_filters
|
||||||
28
docs/reference/db.md
Normal file
28
docs/reference/db.md
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# `db`
|
||||||
|
|
||||||
|
Here's the reference for all database session utilities, transaction helpers, and locking functions.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.db`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import (
|
||||||
|
LockMode,
|
||||||
|
create_db_dependency,
|
||||||
|
create_db_context,
|
||||||
|
get_transaction,
|
||||||
|
lock_tables,
|
||||||
|
wait_for_row_change,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.LockMode
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.create_db_dependency
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.create_db_context
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.get_transaction
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.lock_tables
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.wait_for_row_change
|
||||||
13
docs/reference/dependencies.md
Normal file
13
docs/reference/dependencies.md
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# `dependencies`
|
||||||
|
|
||||||
|
Here's the reference for the FastAPI dependency factory functions.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.dependencies`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.dependencies import PathDependency, BodyDependency
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.dependencies.PathDependency
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.dependencies.BodyDependency
|
||||||
34
docs/reference/exceptions.md
Normal file
34
docs/reference/exceptions.md
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
# `exceptions`
|
||||||
|
|
||||||
|
Here's the reference for all exception classes and handler utilities.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.exceptions`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.exceptions import (
|
||||||
|
ApiException,
|
||||||
|
UnauthorizedError,
|
||||||
|
ForbiddenError,
|
||||||
|
NotFoundError,
|
||||||
|
ConflictError,
|
||||||
|
NoSearchableFieldsError,
|
||||||
|
generate_error_responses,
|
||||||
|
init_exceptions_handlers,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.ApiException
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.UnauthorizedError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.ForbiddenError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.NotFoundError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.ConflictError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.generate_error_responses
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.handler.init_exceptions_handlers
|
||||||
31
docs/reference/fixtures.md
Normal file
31
docs/reference/fixtures.md
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
# `fixtures`
|
||||||
|
|
||||||
|
Here's the reference for the fixture registry, enums, and loading utilities.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.fixtures`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import (
|
||||||
|
Context,
|
||||||
|
LoadStrategy,
|
||||||
|
Fixture,
|
||||||
|
FixtureRegistry,
|
||||||
|
load_fixtures,
|
||||||
|
load_fixtures_by_context,
|
||||||
|
get_obj_by_attr,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.enum.Context
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.enum.LoadStrategy
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.registry.Fixture
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.registry.FixtureRegistry
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.utils.load_fixtures
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.utils.load_fixtures_by_context
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.utils.get_obj_by_attr
|
||||||
13
docs/reference/logger.md
Normal file
13
docs/reference/logger.md
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# `logger`
|
||||||
|
|
||||||
|
Here's the reference for the logging utilities.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.logger`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import configure_logging, get_logger
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.logger.configure_logging
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.logger.get_logger
|
||||||
15
docs/reference/metrics.md
Normal file
15
docs/reference/metrics.md
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# `metrics`
|
||||||
|
|
||||||
|
Here's the reference for the Prometheus metrics registry and endpoint handler.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.metrics`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.metrics import Metric, MetricsRegistry, init_metrics
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.metrics.registry.Metric
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.metrics.registry.MetricsRegistry
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.metrics.handler.init_metrics
|
||||||
28
docs/reference/pytest.md
Normal file
28
docs/reference/pytest.md
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# `pytest`
|
||||||
|
|
||||||
|
Here's the reference for all testing utilities and pytest fixtures.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.pytest`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import (
|
||||||
|
register_fixtures,
|
||||||
|
create_async_client,
|
||||||
|
create_db_session,
|
||||||
|
worker_database_url,
|
||||||
|
create_worker_database,
|
||||||
|
cleanup_tables,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.plugin.register_fixtures
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.create_async_client
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.create_db_session
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.worker_database_url
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.create_worker_database
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.cleanup_tables
|
||||||
34
docs/reference/schemas.md
Normal file
34
docs/reference/schemas.md
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
# `schemas` module
|
||||||
|
|
||||||
|
Here's the reference for all response models and types provided by the `schemas` module.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.schemas`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.schemas import (
|
||||||
|
PydanticBase,
|
||||||
|
ResponseStatus,
|
||||||
|
ApiError,
|
||||||
|
BaseResponse,
|
||||||
|
Response,
|
||||||
|
ErrorResponse,
|
||||||
|
Pagination,
|
||||||
|
PaginatedResponse,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.PydanticBase
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.ResponseStatus
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.ApiError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.BaseResponse
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.Response
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.ErrorResponse
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.Pagination
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.PaginatedResponse
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "fastapi-toolsets"
|
name = "fastapi-toolsets"
|
||||||
version = "0.2.0"
|
version = "1.1.1"
|
||||||
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
|
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
@@ -24,18 +24,17 @@ classifiers = [
|
|||||||
"Programming Language :: Python :: 3.11",
|
"Programming Language :: Python :: 3.11",
|
||||||
"Programming Language :: Python :: 3.12",
|
"Programming Language :: Python :: 3.12",
|
||||||
"Programming Language :: Python :: 3.13",
|
"Programming Language :: Python :: 3.13",
|
||||||
|
"Programming Language :: Python :: 3.14",
|
||||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||||
"Topic :: Software Development :: Libraries",
|
"Topic :: Software Development :: Libraries",
|
||||||
"Topic :: Software Development",
|
"Topic :: Software Development",
|
||||||
"Typing :: Typed",
|
"Typing :: Typed",
|
||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"fastapi>=0.100.0",
|
|
||||||
"sqlalchemy[asyncio]>=2.0",
|
|
||||||
"asyncpg>=0.29.0",
|
"asyncpg>=0.29.0",
|
||||||
|
"fastapi>=0.100.0",
|
||||||
"pydantic>=2.0",
|
"pydantic>=2.0",
|
||||||
"typer>=0.9.0",
|
"sqlalchemy[asyncio]>=2.0",
|
||||||
"httpx>=0.25.0",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
@@ -45,23 +44,47 @@ Repository = "https://github.com/d3vyce/fastapi-toolsets"
|
|||||||
Issues = "https://github.com/d3vyce/fastapi-toolsets/issues"
|
Issues = "https://github.com/d3vyce/fastapi-toolsets/issues"
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
test = [
|
cli = [
|
||||||
"pytest>=8.0.0",
|
"typer>=0.9.0",
|
||||||
"pytest-anyio>=0.0.0",
|
|
||||||
"coverage>=7.0.0",
|
|
||||||
"pytest-cov>=4.0.0",
|
|
||||||
]
|
]
|
||||||
dev = [
|
metrics = [
|
||||||
"fastapi-toolsets[test]",
|
"prometheus_client>=0.20.0",
|
||||||
"ruff>=0.1.0",
|
]
|
||||||
"ty>=0.0.1a0",
|
pytest = [
|
||||||
|
"httpx>=0.25.0",
|
||||||
|
"pytest-xdist>=3.0.0",
|
||||||
|
"pytest>=8.0.0",
|
||||||
|
]
|
||||||
|
all = [
|
||||||
|
"fastapi-toolsets[cli,metrics,pytest]",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
fastapi-toolsets = "fastapi_toolsets.cli:app"
|
manager = "fastapi_toolsets.cli.app:cli"
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = [
|
||||||
|
{include-group = "tests"},
|
||||||
|
{include-group = "docs"},
|
||||||
|
"fastapi-toolsets[all]",
|
||||||
|
"ruff>=0.1.0",
|
||||||
|
"ty>=0.0.1a0",
|
||||||
|
]
|
||||||
|
tests = [
|
||||||
|
"coverage>=7.0.0",
|
||||||
|
"httpx>=0.25.0",
|
||||||
|
"pytest-anyio>=0.0.0",
|
||||||
|
"pytest-cov>=4.0.0",
|
||||||
|
"pytest-xdist>=3.0.0",
|
||||||
|
"pytest>=8.0.0",
|
||||||
|
]
|
||||||
|
docs = [
|
||||||
|
"mkdocstrings-python>=2.0.2",
|
||||||
|
"zensical>=0.0.23",
|
||||||
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["uv_build>=0.9.26,<0.10.0"]
|
requires = ["uv_build>=0.10,<0.11.0"]
|
||||||
build-backend = "uv_build"
|
build-backend = "uv_build"
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
|
|||||||
@@ -21,4 +21,4 @@ Example usage:
|
|||||||
return Response(data={"user": user.username}, message="Success")
|
return Response(data={"user": user.username}, message="Success")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.2.0"
|
__version__ = "1.1.1"
|
||||||
|
|||||||
9
src/fastapi_toolsets/_imports.py
Normal file
9
src/fastapi_toolsets/_imports.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
"""Optional dependency helpers."""
|
||||||
|
|
||||||
|
|
||||||
|
def require_extra(package: str, extra: str) -> None:
|
||||||
|
"""Raise *ImportError* with an actionable install instruction."""
|
||||||
|
raise ImportError(
|
||||||
|
f"'{package}' is required to use this module. "
|
||||||
|
f"Install it with: pip install fastapi-toolsets[{extra}]"
|
||||||
|
)
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
"""CLI for FastAPI projects."""
|
"""CLI for FastAPI projects."""
|
||||||
|
|
||||||
from .app import app, register_command
|
from .utils import async_command
|
||||||
|
|
||||||
__all__ = ["app", "register_command"]
|
__all__ = ["async_command"]
|
||||||
|
|||||||
@@ -1,97 +1,37 @@
|
|||||||
"""Main CLI application."""
|
"""Main CLI application."""
|
||||||
|
|
||||||
import importlib.util
|
try:
|
||||||
import sys
|
import typer
|
||||||
from pathlib import Path
|
except ImportError:
|
||||||
from typing import Annotated
|
from .._imports import require_extra
|
||||||
|
|
||||||
import typer
|
require_extra(package="typer", extra="cli")
|
||||||
|
|
||||||
from .commands import fixtures
|
from ..logger import configure_logging
|
||||||
|
from .config import get_custom_cli
|
||||||
|
from .pyproject import load_pyproject
|
||||||
|
|
||||||
app = typer.Typer(
|
# Use custom CLI if configured, otherwise create default one
|
||||||
name="fastapi-utils",
|
_custom_cli = get_custom_cli()
|
||||||
|
|
||||||
|
if _custom_cli is not None:
|
||||||
|
cli = _custom_cli
|
||||||
|
else:
|
||||||
|
cli = typer.Typer(
|
||||||
|
name="manager",
|
||||||
help="CLI utilities for FastAPI projects.",
|
help="CLI utilities for FastAPI projects.",
|
||||||
no_args_is_help=True,
|
no_args_is_help=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Register built-in commands
|
_config = load_pyproject()
|
||||||
app.add_typer(fixtures.app, name="fixtures")
|
if _config.get("fixtures") and _config.get("db_context"):
|
||||||
|
from .commands.fixtures import fixture_cli
|
||||||
|
|
||||||
|
cli.add_typer(fixture_cli, name="fixtures")
|
||||||
|
|
||||||
|
|
||||||
def register_command(command: typer.Typer, name: str) -> None:
|
@cli.callback()
|
||||||
"""Register a custom command group.
|
def main(ctx: typer.Context) -> None:
|
||||||
|
|
||||||
Args:
|
|
||||||
command: Typer app for the command group
|
|
||||||
name: Name for the command group
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# In your project's cli.py:
|
|
||||||
import typer
|
|
||||||
from fastapi_toolsets.cli import app, register_command
|
|
||||||
|
|
||||||
my_commands = typer.Typer()
|
|
||||||
|
|
||||||
@my_commands.command()
|
|
||||||
def seed():
|
|
||||||
'''Seed the database.'''
|
|
||||||
...
|
|
||||||
|
|
||||||
register_command(my_commands, "db")
|
|
||||||
# Now available as: fastapi-utils db seed
|
|
||||||
"""
|
|
||||||
app.add_typer(command, name=name)
|
|
||||||
|
|
||||||
|
|
||||||
@app.callback()
|
|
||||||
def main(
|
|
||||||
ctx: typer.Context,
|
|
||||||
config: Annotated[
|
|
||||||
Path | None,
|
|
||||||
typer.Option(
|
|
||||||
"--config",
|
|
||||||
"-c",
|
|
||||||
help="Path to project config file (Python module with fixtures registry).",
|
|
||||||
envvar="FASTAPI_TOOLSETS_CONFIG",
|
|
||||||
),
|
|
||||||
] = None,
|
|
||||||
) -> None:
|
|
||||||
"""FastAPI utilities CLI."""
|
"""FastAPI utilities CLI."""
|
||||||
|
configure_logging()
|
||||||
ctx.ensure_object(dict)
|
ctx.ensure_object(dict)
|
||||||
|
|
||||||
if config:
|
|
||||||
ctx.obj["config_path"] = config
|
|
||||||
# Load the config module
|
|
||||||
config_module = _load_module_from_path(config)
|
|
||||||
ctx.obj["config_module"] = config_module
|
|
||||||
|
|
||||||
|
|
||||||
def _load_module_from_path(path: Path) -> object:
|
|
||||||
"""Load a Python module from a file path.
|
|
||||||
|
|
||||||
Handles both absolute and relative imports by adding the config's
|
|
||||||
parent directory to sys.path temporarily.
|
|
||||||
"""
|
|
||||||
path = path.resolve()
|
|
||||||
|
|
||||||
# Add the parent directory to sys.path to support relative imports
|
|
||||||
parent_dir = str(
|
|
||||||
path.parent.parent
|
|
||||||
) # Go up two levels (e.g., from app/cli_config.py to project root)
|
|
||||||
if parent_dir not in sys.path:
|
|
||||||
sys.path.insert(0, parent_dir)
|
|
||||||
|
|
||||||
# Also add immediate parent for direct module imports
|
|
||||||
immediate_parent = str(path.parent)
|
|
||||||
if immediate_parent not in sys.path:
|
|
||||||
sys.path.insert(0, immediate_parent)
|
|
||||||
|
|
||||||
spec = importlib.util.spec_from_file_location("config", path)
|
|
||||||
if spec is None or spec.loader is None:
|
|
||||||
raise typer.BadParameter(f"Cannot load module from {path}")
|
|
||||||
|
|
||||||
module = importlib.util.module_from_spec(spec)
|
|
||||||
sys.modules["config"] = module
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
return module
|
|
||||||
|
|||||||
@@ -1,138 +1,66 @@
|
|||||||
"""Fixture management commands."""
|
"""Fixture management commands."""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
from ...fixtures import Context, FixtureRegistry, LoadStrategy, load_fixtures_by_context
|
from ...fixtures import Context, LoadStrategy, load_fixtures_by_context
|
||||||
|
from ..config import get_db_context, get_fixtures_registry
|
||||||
|
from ..utils import async_command
|
||||||
|
|
||||||
app = typer.Typer(
|
fixture_cli = typer.Typer(
|
||||||
name="fixtures",
|
name="fixtures",
|
||||||
help="Manage database fixtures.",
|
help="Manage database fixtures.",
|
||||||
no_args_is_help=True,
|
no_args_is_help=True,
|
||||||
)
|
)
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
def _get_registry(ctx: typer.Context) -> FixtureRegistry:
|
@fixture_cli.command("list")
|
||||||
"""Get fixture registry from context."""
|
|
||||||
config = ctx.obj.get("config_module") if ctx.obj else None
|
|
||||||
if config is None:
|
|
||||||
raise typer.BadParameter(
|
|
||||||
"No config provided. Use --config to specify a config file with a 'fixtures' registry."
|
|
||||||
)
|
|
||||||
|
|
||||||
registry = getattr(config, "fixtures", None)
|
|
||||||
if registry is None:
|
|
||||||
raise typer.BadParameter(
|
|
||||||
"Config module must have a 'fixtures' attribute (FixtureRegistry instance)."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not isinstance(registry, FixtureRegistry):
|
|
||||||
raise typer.BadParameter(
|
|
||||||
f"'fixtures' must be a FixtureRegistry instance, got {type(registry).__name__}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return registry
|
|
||||||
|
|
||||||
|
|
||||||
def _get_db_context(ctx: typer.Context):
|
|
||||||
"""Get database context manager from config."""
|
|
||||||
config = ctx.obj.get("config_module") if ctx.obj else None
|
|
||||||
if config is None:
|
|
||||||
raise typer.BadParameter("No config provided.")
|
|
||||||
|
|
||||||
get_db_context = getattr(config, "get_db_context", None)
|
|
||||||
if get_db_context is None:
|
|
||||||
raise typer.BadParameter("Config module must have a 'get_db_context' function.")
|
|
||||||
|
|
||||||
return get_db_context
|
|
||||||
|
|
||||||
|
|
||||||
@app.command("list")
|
|
||||||
def list_fixtures(
|
def list_fixtures(
|
||||||
ctx: typer.Context,
|
ctx: typer.Context,
|
||||||
context: Annotated[
|
context: Annotated[
|
||||||
str | None,
|
Context | None,
|
||||||
typer.Option(
|
typer.Option(
|
||||||
"--context",
|
"--context",
|
||||||
"-c",
|
"-c",
|
||||||
help="Filter by context (base, production, development, testing).",
|
help="Filter by context.",
|
||||||
),
|
),
|
||||||
] = None,
|
] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""List all registered fixtures."""
|
"""List all registered fixtures."""
|
||||||
registry = _get_registry(ctx)
|
registry = get_fixtures_registry()
|
||||||
|
fixtures = registry.get_by_context(context.value) if context else registry.get_all()
|
||||||
if context:
|
|
||||||
fixtures = registry.get_by_context(context)
|
|
||||||
else:
|
|
||||||
fixtures = registry.get_all()
|
|
||||||
|
|
||||||
if not fixtures:
|
if not fixtures:
|
||||||
typer.echo("No fixtures found.")
|
print("No fixtures found.")
|
||||||
return
|
return
|
||||||
|
|
||||||
typer.echo(f"\n{'Name':<30} {'Contexts':<30} {'Dependencies'}")
|
table = Table("Name", "Contexts", "Dependencies")
|
||||||
typer.echo("-" * 80)
|
|
||||||
|
|
||||||
for fixture in fixtures:
|
for fixture in fixtures:
|
||||||
contexts = ", ".join(fixture.contexts)
|
contexts = ", ".join(fixture.contexts)
|
||||||
deps = ", ".join(fixture.depends_on) if fixture.depends_on else "-"
|
deps = ", ".join(fixture.depends_on) if fixture.depends_on else "-"
|
||||||
typer.echo(f"{fixture.name:<30} {contexts:<30} {deps}")
|
table.add_row(fixture.name, contexts, deps)
|
||||||
|
|
||||||
typer.echo(f"\nTotal: {len(fixtures)} fixture(s)")
|
console.print(table)
|
||||||
|
print(f"\nTotal: {len(fixtures)} fixture(s)")
|
||||||
|
|
||||||
|
|
||||||
@app.command("graph")
|
@fixture_cli.command("load")
|
||||||
def show_graph(
|
@async_command
|
||||||
ctx: typer.Context,
|
async def load(
|
||||||
fixture_name: Annotated[
|
|
||||||
str | None,
|
|
||||||
typer.Argument(help="Show dependencies for a specific fixture."),
|
|
||||||
] = None,
|
|
||||||
) -> None:
|
|
||||||
"""Show fixture dependency graph."""
|
|
||||||
registry = _get_registry(ctx)
|
|
||||||
|
|
||||||
if fixture_name:
|
|
||||||
try:
|
|
||||||
order = registry.resolve_dependencies(fixture_name)
|
|
||||||
typer.echo(f"\nDependency chain for '{fixture_name}':\n")
|
|
||||||
for i, name in enumerate(order):
|
|
||||||
indent = " " * i
|
|
||||||
arrow = "└─> " if i > 0 else ""
|
|
||||||
typer.echo(f"{indent}{arrow}{name}")
|
|
||||||
except KeyError:
|
|
||||||
typer.echo(f"Fixture '{fixture_name}' not found.", err=True)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
else:
|
|
||||||
# Show full graph
|
|
||||||
fixtures = registry.get_all()
|
|
||||||
|
|
||||||
typer.echo("\nFixture Dependency Graph:\n")
|
|
||||||
for fixture in fixtures:
|
|
||||||
deps = (
|
|
||||||
f" -> [{', '.join(fixture.depends_on)}]" if fixture.depends_on else ""
|
|
||||||
)
|
|
||||||
typer.echo(f" {fixture.name}{deps}")
|
|
||||||
|
|
||||||
|
|
||||||
@app.command("load")
|
|
||||||
def load(
|
|
||||||
ctx: typer.Context,
|
ctx: typer.Context,
|
||||||
contexts: Annotated[
|
contexts: Annotated[
|
||||||
list[str] | None,
|
list[Context] | None,
|
||||||
typer.Argument(
|
typer.Argument(help="Contexts to load."),
|
||||||
help="Contexts to load (base, production, development, testing)."
|
|
||||||
),
|
|
||||||
] = None,
|
] = None,
|
||||||
strategy: Annotated[
|
strategy: Annotated[
|
||||||
str,
|
LoadStrategy,
|
||||||
typer.Option(
|
typer.Option("--strategy", "-s", help="Load strategy."),
|
||||||
"--strategy", "-s", help="Load strategy: merge, insert, skip_existing."
|
] = LoadStrategy.MERGE,
|
||||||
),
|
|
||||||
] = "merge",
|
|
||||||
dry_run: Annotated[
|
dry_run: Annotated[
|
||||||
bool,
|
bool,
|
||||||
typer.Option(
|
typer.Option(
|
||||||
@@ -141,85 +69,32 @@ def load(
|
|||||||
] = False,
|
] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Load fixtures into the database."""
|
"""Load fixtures into the database."""
|
||||||
registry = _get_registry(ctx)
|
registry = get_fixtures_registry()
|
||||||
get_db_context = _get_db_context(ctx)
|
db_context = get_db_context()
|
||||||
|
|
||||||
# Parse contexts
|
context_list = [c.value for c in contexts] if contexts else [Context.BASE]
|
||||||
if contexts:
|
|
||||||
context_list = contexts
|
|
||||||
else:
|
|
||||||
context_list = [Context.BASE]
|
|
||||||
|
|
||||||
# Parse strategy
|
|
||||||
try:
|
|
||||||
load_strategy = LoadStrategy(strategy)
|
|
||||||
except ValueError:
|
|
||||||
typer.echo(
|
|
||||||
f"Invalid strategy: {strategy}. Use: merge, insert, skip_existing", err=True
|
|
||||||
)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
# Resolve what will be loaded
|
|
||||||
ordered = registry.resolve_context_dependencies(*context_list)
|
ordered = registry.resolve_context_dependencies(*context_list)
|
||||||
|
|
||||||
if not ordered:
|
if not ordered:
|
||||||
typer.echo("No fixtures to load for the specified context(s).")
|
print("No fixtures to load for the specified context(s).")
|
||||||
return
|
return
|
||||||
|
|
||||||
typer.echo(f"\nFixtures to load ({load_strategy.value} strategy):")
|
print(f"\nFixtures to load ({strategy.value} strategy):")
|
||||||
for name in ordered:
|
for name in ordered:
|
||||||
fixture = registry.get(name)
|
fixture = registry.get(name)
|
||||||
instances = list(fixture.func())
|
instances = list(fixture.func())
|
||||||
model_name = type(instances[0]).__name__ if instances else "?"
|
model_name = type(instances[0]).__name__ if instances else "?"
|
||||||
typer.echo(f" - {name}: {len(instances)} {model_name}(s)")
|
print(f" - {name}: {len(instances)} {model_name}(s)")
|
||||||
|
|
||||||
if dry_run:
|
if dry_run:
|
||||||
typer.echo("\n[Dry run - no changes made]")
|
print("\n[Dry run - no changes made]")
|
||||||
return
|
return
|
||||||
|
|
||||||
typer.echo("\nLoading...")
|
async with db_context() as session:
|
||||||
|
|
||||||
async def do_load():
|
|
||||||
async with get_db_context() as session:
|
|
||||||
result = await load_fixtures_by_context(
|
result = await load_fixtures_by_context(
|
||||||
session, registry, *context_list, strategy=load_strategy
|
session, registry, *context_list, strategy=strategy
|
||||||
)
|
)
|
||||||
return result
|
|
||||||
|
|
||||||
result = asyncio.run(do_load())
|
|
||||||
|
|
||||||
total = sum(len(items) for items in result.values())
|
total = sum(len(items) for items in result.values())
|
||||||
typer.echo(f"\nLoaded {total} record(s) successfully.")
|
print(f"\nLoaded {total} record(s) successfully.")
|
||||||
|
|
||||||
|
|
||||||
@app.command("show")
|
|
||||||
def show_fixture(
|
|
||||||
ctx: typer.Context,
|
|
||||||
name: Annotated[str, typer.Argument(help="Fixture name to show.")],
|
|
||||||
) -> None:
|
|
||||||
"""Show details of a specific fixture."""
|
|
||||||
registry = _get_registry(ctx)
|
|
||||||
|
|
||||||
try:
|
|
||||||
fixture = registry.get(name)
|
|
||||||
except KeyError:
|
|
||||||
typer.echo(f"Fixture '{name}' not found.", err=True)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
typer.echo(f"\nFixture: {fixture.name}")
|
|
||||||
typer.echo(f"Contexts: {', '.join(fixture.contexts)}")
|
|
||||||
typer.echo(
|
|
||||||
f"Dependencies: {', '.join(fixture.depends_on) if fixture.depends_on else 'None'}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Show instances
|
|
||||||
instances = list(fixture.func())
|
|
||||||
if instances:
|
|
||||||
model_name = type(instances[0]).__name__
|
|
||||||
typer.echo(f"\nInstances ({len(instances)} {model_name}):")
|
|
||||||
for instance in instances[:10]: # Limit to 10
|
|
||||||
typer.echo(f" - {instance!r}")
|
|
||||||
if len(instances) > 10:
|
|
||||||
typer.echo(f" ... and {len(instances) - 10} more")
|
|
||||||
else:
|
|
||||||
typer.echo("\nNo instances (empty fixture)")
|
|
||||||
|
|||||||
125
src/fastapi_toolsets/cli/config.py
Normal file
125
src/fastapi_toolsets/cli/config.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""CLI configuration and dynamic imports."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
from typing import TYPE_CHECKING, Any, Literal, overload
|
||||||
|
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from .pyproject import find_pyproject, load_pyproject
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..fixtures import FixtureRegistry
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_project_in_path():
|
||||||
|
"""Add project root to sys.path if not installed in editable mode."""
|
||||||
|
pyproject = find_pyproject()
|
||||||
|
if pyproject:
|
||||||
|
project_root = str(pyproject.parent)
|
||||||
|
if project_root not in sys.path:
|
||||||
|
sys.path.insert(0, project_root)
|
||||||
|
|
||||||
|
|
||||||
|
def import_from_string(import_path: str) -> Any:
|
||||||
|
"""Import an object from a dotted string path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
import_path: Import path in ``"module.submodule:attribute"`` format
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The imported attribute
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
typer.BadParameter: If the import path is invalid or import fails
|
||||||
|
"""
|
||||||
|
if ":" not in import_path:
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"Invalid import path '{import_path}'. Expected format: 'module:attribute'"
|
||||||
|
)
|
||||||
|
|
||||||
|
module_path, attr_name = import_path.rsplit(":", 1)
|
||||||
|
|
||||||
|
_ensure_project_in_path()
|
||||||
|
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(module_path)
|
||||||
|
except ImportError as e:
|
||||||
|
raise typer.BadParameter(f"Cannot import module '{module_path}': {e}")
|
||||||
|
|
||||||
|
if not hasattr(module, attr_name):
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"Module '{module_path}' has no attribute '{attr_name}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
return getattr(module, attr_name)
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_config_value(key: str, required: Literal[True]) -> Any: ... # pragma: no cover
|
||||||
|
@overload
|
||||||
|
def get_config_value(
|
||||||
|
key: str, required: bool = False
|
||||||
|
) -> Any | None: ... # pragma: no cover
|
||||||
|
def get_config_value(key: str, required: bool = False) -> Any | None:
|
||||||
|
"""Get a configuration value from pyproject.toml.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The configuration key in [tool.fastapi-toolsets].
|
||||||
|
required: If True, raises an error when the key is missing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The configuration value, or None if not found and not required.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
typer.BadParameter: If required=True and the key is missing.
|
||||||
|
"""
|
||||||
|
config = load_pyproject()
|
||||||
|
value = config.get(key)
|
||||||
|
|
||||||
|
if required and value is None:
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"No '{key}' configured. "
|
||||||
|
f"Add '{key}' to [tool.fastapi-toolsets] in pyproject.toml."
|
||||||
|
)
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def get_fixtures_registry() -> FixtureRegistry:
|
||||||
|
"""Import and return the fixtures registry from config."""
|
||||||
|
from ..fixtures import FixtureRegistry
|
||||||
|
|
||||||
|
import_path = get_config_value("fixtures", required=True)
|
||||||
|
registry = import_from_string(import_path)
|
||||||
|
|
||||||
|
if not isinstance(registry, FixtureRegistry):
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"'fixtures' must be a FixtureRegistry instance, got {type(registry).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return registry
|
||||||
|
|
||||||
|
|
||||||
|
def get_db_context() -> Any:
|
||||||
|
"""Import and return the db_context function from config."""
|
||||||
|
import_path = get_config_value("db_context", required=True)
|
||||||
|
return import_from_string(import_path)
|
||||||
|
|
||||||
|
|
||||||
|
def get_custom_cli() -> typer.Typer | None:
|
||||||
|
"""Import and return the custom CLI Typer instance from config."""
|
||||||
|
import_path = get_config_value("custom_cli")
|
||||||
|
if not import_path:
|
||||||
|
return None
|
||||||
|
|
||||||
|
custom = import_from_string(import_path)
|
||||||
|
|
||||||
|
if not isinstance(custom, typer.Typer):
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"'custom_cli' must be a Typer instance, got {type(custom).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return custom
|
||||||
49
src/fastapi_toolsets/cli/pyproject.py
Normal file
49
src/fastapi_toolsets/cli/pyproject.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
"""Pyproject.toml discovery and loading."""
|
||||||
|
|
||||||
|
import tomllib
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
TOOL_NAME = "fastapi-toolsets"
|
||||||
|
|
||||||
|
|
||||||
|
def find_pyproject(start_path: Path | None = None) -> Path | None:
|
||||||
|
"""Find pyproject.toml by walking up the directory tree.
|
||||||
|
|
||||||
|
Similar to how pytest, black, and ruff discover their config files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_path: Directory to start searching from. Defaults to cwd.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to pyproject.toml, or None if not found.
|
||||||
|
"""
|
||||||
|
path = (start_path or Path.cwd()).resolve()
|
||||||
|
|
||||||
|
for directory in [path, *path.parents]:
|
||||||
|
pyproject = directory / "pyproject.toml"
|
||||||
|
if pyproject.is_file():
|
||||||
|
return pyproject
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def load_pyproject(path: Path | None = None) -> dict:
|
||||||
|
"""Load tool configuration from pyproject.toml.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Explicit path to pyproject.toml. If None, searches up from cwd.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The [tool.fastapi-toolsets] section as a dict, or empty dict if not found.
|
||||||
|
"""
|
||||||
|
pyproject_path = path or find_pyproject()
|
||||||
|
|
||||||
|
if not pyproject_path:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(pyproject_path, "rb") as f:
|
||||||
|
data = tomllib.load(f)
|
||||||
|
return data.get("tool", {}).get(TOOL_NAME, {})
|
||||||
|
except (OSError, tomllib.TOMLDecodeError):
|
||||||
|
return {}
|
||||||
29
src/fastapi_toolsets/cli/utils.py
Normal file
29
src/fastapi_toolsets/cli/utils.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
"""CLI utility functions."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import functools
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
|
from typing import Any, ParamSpec, TypeVar
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def async_command(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, T]:
|
||||||
|
"""Decorator to run an async function as a sync CLI command.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
@fixture_cli.command("load")
|
||||||
|
@async_command
|
||||||
|
async def load(ctx: typer.Context) -> None:
|
||||||
|
async with get_db_context() as session:
|
||||||
|
await load_fixtures(session, registry)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||||
|
return asyncio.run(func(*args, **kwargs))
|
||||||
|
|
||||||
|
return wrapper
|
||||||
@@ -1,378 +0,0 @@
|
|||||||
"""Generic async CRUD operations for SQLAlchemy models."""
|
|
||||||
|
|
||||||
from collections.abc import Sequence
|
|
||||||
from typing import Any, ClassVar, Generic, Self, TypeVar, cast
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from sqlalchemy import and_, func, select
|
|
||||||
from sqlalchemy import delete as sql_delete
|
|
||||||
from sqlalchemy.dialects.postgresql import insert
|
|
||||||
from sqlalchemy.exc import NoResultFound
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
|
||||||
from sqlalchemy.sql.roles import WhereHavingRole
|
|
||||||
|
|
||||||
from .db import get_transaction
|
|
||||||
from .exceptions import NotFoundError
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"AsyncCrud",
|
|
||||||
"CrudFactory",
|
|
||||||
]
|
|
||||||
|
|
||||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncCrud(Generic[ModelType]):
|
|
||||||
"""Generic async CRUD operations for SQLAlchemy models.
|
|
||||||
|
|
||||||
Subclass this and set the `model` class variable, or use `CrudFactory`.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
class UserCrud(AsyncCrud[User]):
|
|
||||||
model = User
|
|
||||||
|
|
||||||
# Or use the factory:
|
|
||||||
UserCrud = CrudFactory(User)
|
|
||||||
|
|
||||||
# Then use it:
|
|
||||||
user = await UserCrud.get(session, [User.id == 1])
|
|
||||||
users = await UserCrud.get_multi(session, limit=10)
|
|
||||||
"""
|
|
||||||
|
|
||||||
model: ClassVar[type[DeclarativeBase]]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def create(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
obj: BaseModel,
|
|
||||||
) -> ModelType:
|
|
||||||
"""Create a new record in the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
obj: Pydantic model with data to create
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Created model instance
|
|
||||||
"""
|
|
||||||
async with get_transaction(session):
|
|
||||||
db_model = cls.model(**obj.model_dump())
|
|
||||||
session.add(db_model)
|
|
||||||
await session.refresh(db_model)
|
|
||||||
return cast(ModelType, db_model)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
filters: list[Any],
|
|
||||||
*,
|
|
||||||
with_for_update: bool = False,
|
|
||||||
load_options: list[Any] | None = None,
|
|
||||||
) -> ModelType:
|
|
||||||
"""Get exactly one record. Raises NotFoundError if not found.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
with_for_update: Lock the row for update
|
|
||||||
load_options: SQLAlchemy loader options (e.g., selectinload)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Model instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
NotFoundError: If no record found
|
|
||||||
MultipleResultsFound: If more than one record found
|
|
||||||
"""
|
|
||||||
q = select(cls.model).where(and_(*filters))
|
|
||||||
if load_options:
|
|
||||||
q = q.options(*load_options)
|
|
||||||
if with_for_update:
|
|
||||||
q = q.with_for_update()
|
|
||||||
result = await session.execute(q)
|
|
||||||
item = result.unique().scalar_one_or_none()
|
|
||||||
if not item:
|
|
||||||
raise NotFoundError()
|
|
||||||
return cast(ModelType, item)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def first(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
filters: list[Any] | None = None,
|
|
||||||
*,
|
|
||||||
load_options: list[Any] | None = None,
|
|
||||||
) -> ModelType | None:
|
|
||||||
"""Get the first matching record, or None.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
load_options: SQLAlchemy loader options
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Model instance or None
|
|
||||||
"""
|
|
||||||
q = select(cls.model)
|
|
||||||
if filters:
|
|
||||||
q = q.where(and_(*filters))
|
|
||||||
if load_options:
|
|
||||||
q = q.options(*load_options)
|
|
||||||
result = await session.execute(q)
|
|
||||||
return cast(ModelType | None, result.unique().scalars().first())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get_multi(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
*,
|
|
||||||
filters: list[Any] | None = None,
|
|
||||||
load_options: list[Any] | None = None,
|
|
||||||
order_by: Any | None = None,
|
|
||||||
limit: int | None = None,
|
|
||||||
offset: int | None = None,
|
|
||||||
) -> Sequence[ModelType]:
|
|
||||||
"""Get multiple records from the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
load_options: SQLAlchemy loader options
|
|
||||||
order_by: Column or list of columns to order by
|
|
||||||
limit: Max number of rows to return
|
|
||||||
offset: Rows to skip
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of model instances
|
|
||||||
"""
|
|
||||||
q = select(cls.model)
|
|
||||||
if filters:
|
|
||||||
q = q.where(and_(*filters))
|
|
||||||
if load_options:
|
|
||||||
q = q.options(*load_options)
|
|
||||||
if order_by is not None:
|
|
||||||
q = q.order_by(order_by)
|
|
||||||
if offset is not None:
|
|
||||||
q = q.offset(offset)
|
|
||||||
if limit is not None:
|
|
||||||
q = q.limit(limit)
|
|
||||||
result = await session.execute(q)
|
|
||||||
return cast(Sequence[ModelType], result.unique().scalars().all())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def update(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
obj: BaseModel,
|
|
||||||
filters: list[Any],
|
|
||||||
*,
|
|
||||||
exclude_unset: bool = True,
|
|
||||||
exclude_none: bool = False,
|
|
||||||
) -> ModelType:
|
|
||||||
"""Update a record in the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
obj: Pydantic model with update data
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
exclude_unset: Exclude fields not explicitly set in the schema
|
|
||||||
exclude_none: Exclude fields with None value
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated model instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
NotFoundError: If no record found
|
|
||||||
"""
|
|
||||||
async with get_transaction(session):
|
|
||||||
db_model = await cls.get(session=session, filters=filters)
|
|
||||||
values = obj.model_dump(
|
|
||||||
exclude_unset=exclude_unset, exclude_none=exclude_none
|
|
||||||
)
|
|
||||||
for key, value in values.items():
|
|
||||||
setattr(db_model, key, value)
|
|
||||||
await session.refresh(db_model)
|
|
||||||
return db_model
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def upsert(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
obj: BaseModel,
|
|
||||||
index_elements: list[str],
|
|
||||||
*,
|
|
||||||
set_: BaseModel | None = None,
|
|
||||||
where: WhereHavingRole | None = None,
|
|
||||||
) -> ModelType | None:
|
|
||||||
"""Create or update a record (PostgreSQL only).
|
|
||||||
|
|
||||||
Uses INSERT ... ON CONFLICT for atomic upsert.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
obj: Pydantic model with data
|
|
||||||
index_elements: Columns for ON CONFLICT (unique constraint)
|
|
||||||
set_: Pydantic model for ON CONFLICT DO UPDATE SET
|
|
||||||
where: WHERE clause for ON CONFLICT DO UPDATE
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Model instance
|
|
||||||
"""
|
|
||||||
async with get_transaction(session):
|
|
||||||
values = obj.model_dump(exclude_unset=True)
|
|
||||||
q = insert(cls.model).values(**values)
|
|
||||||
if set_:
|
|
||||||
q = q.on_conflict_do_update(
|
|
||||||
index_elements=index_elements,
|
|
||||||
set_=set_.model_dump(exclude_unset=True),
|
|
||||||
where=where,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
q = q.on_conflict_do_nothing(index_elements=index_elements)
|
|
||||||
q = q.returning(cls.model)
|
|
||||||
result = await session.execute(q)
|
|
||||||
try:
|
|
||||||
db_model = result.unique().scalar_one()
|
|
||||||
except NoResultFound:
|
|
||||||
db_model = await cls.first(
|
|
||||||
session=session,
|
|
||||||
filters=[getattr(cls.model, k) == v for k, v in values.items()],
|
|
||||||
)
|
|
||||||
return cast(ModelType | None, db_model)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def delete(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
filters: list[Any],
|
|
||||||
) -> bool:
|
|
||||||
"""Delete records from the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if deletion was executed
|
|
||||||
"""
|
|
||||||
async with get_transaction(session):
|
|
||||||
q = sql_delete(cls.model).where(and_(*filters))
|
|
||||||
await session.execute(q)
|
|
||||||
return True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def count(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
filters: list[Any] | None = None,
|
|
||||||
) -> int:
|
|
||||||
"""Count records matching the filters.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of matching records
|
|
||||||
"""
|
|
||||||
q = select(func.count()).select_from(cls.model)
|
|
||||||
if filters:
|
|
||||||
q = q.where(and_(*filters))
|
|
||||||
result = await session.execute(q)
|
|
||||||
return result.scalar_one()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def exists(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
filters: list[Any],
|
|
||||||
) -> bool:
|
|
||||||
"""Check if a record exists.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if at least one record matches
|
|
||||||
"""
|
|
||||||
q = select(cls.model).where(and_(*filters)).exists().select()
|
|
||||||
result = await session.execute(q)
|
|
||||||
return bool(result.scalar())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def paginate(
|
|
||||||
cls: type[Self],
|
|
||||||
session: AsyncSession,
|
|
||||||
*,
|
|
||||||
filters: list[Any] | None = None,
|
|
||||||
load_options: list[Any] | None = None,
|
|
||||||
order_by: Any | None = None,
|
|
||||||
page: int = 1,
|
|
||||||
items_per_page: int = 20,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Get paginated results with metadata.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: DB async session
|
|
||||||
filters: List of SQLAlchemy filter conditions
|
|
||||||
load_options: SQLAlchemy loader options
|
|
||||||
order_by: Column or list of columns to order by
|
|
||||||
page: Page number (1-indexed)
|
|
||||||
items_per_page: Number of items per page
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict with 'data' and 'pagination' keys
|
|
||||||
"""
|
|
||||||
filters = filters or []
|
|
||||||
offset = (page - 1) * items_per_page
|
|
||||||
|
|
||||||
items = await cls.get_multi(
|
|
||||||
session,
|
|
||||||
filters=filters,
|
|
||||||
load_options=load_options,
|
|
||||||
order_by=order_by,
|
|
||||||
limit=items_per_page,
|
|
||||||
offset=offset,
|
|
||||||
)
|
|
||||||
|
|
||||||
total_count = await cls.count(session, filters=filters)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"data": items,
|
|
||||||
"pagination": {
|
|
||||||
"total_count": total_count,
|
|
||||||
"items_per_page": items_per_page,
|
|
||||||
"page": page,
|
|
||||||
"has_more": page * items_per_page < total_count,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def CrudFactory(
|
|
||||||
model: type[ModelType],
|
|
||||||
) -> type[AsyncCrud[ModelType]]:
|
|
||||||
"""Create a CRUD class for a specific model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: SQLAlchemy model class
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AsyncCrud subclass bound to the model
|
|
||||||
|
|
||||||
Example:
|
|
||||||
from fastapi_toolsets.crud import CrudFactory
|
|
||||||
from myapp.models import User, Post
|
|
||||||
|
|
||||||
UserCrud = CrudFactory(User)
|
|
||||||
PostCrud = CrudFactory(Post)
|
|
||||||
|
|
||||||
# Usage
|
|
||||||
user = await UserCrud.get(session, [User.id == 1])
|
|
||||||
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
|
|
||||||
"""
|
|
||||||
cls = type(f"Async{model.__name__}Crud", (AsyncCrud,), {"model": model})
|
|
||||||
return cast(type[AsyncCrud[ModelType]], cls)
|
|
||||||
17
src/fastapi_toolsets/crud/__init__.py
Normal file
17
src/fastapi_toolsets/crud/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""Generic async CRUD operations for SQLAlchemy models."""
|
||||||
|
|
||||||
|
from ..exceptions import NoSearchableFieldsError
|
||||||
|
from .factory import CrudFactory, JoinType, M2MFieldType
|
||||||
|
from .search import (
|
||||||
|
SearchConfig,
|
||||||
|
get_searchable_fields,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CrudFactory",
|
||||||
|
"get_searchable_fields",
|
||||||
|
"JoinType",
|
||||||
|
"M2MFieldType",
|
||||||
|
"NoSearchableFieldsError",
|
||||||
|
"SearchConfig",
|
||||||
|
]
|
||||||
1101
src/fastapi_toolsets/crud/factory.py
Normal file
1101
src/fastapi_toolsets/crud/factory.py
Normal file
File diff suppressed because it is too large
Load Diff
146
src/fastapi_toolsets/crud/search.py
Normal file
146
src/fastapi_toolsets/crud/search.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
"""Search utilities for AsyncCrud."""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Any, Literal
|
||||||
|
|
||||||
|
from sqlalchemy import String, or_
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||||
|
|
||||||
|
from ..exceptions import NoSearchableFieldsError
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sqlalchemy.sql.elements import ColumnElement
|
||||||
|
|
||||||
|
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SearchConfig:
|
||||||
|
"""Advanced search configuration.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
query: The search string
|
||||||
|
fields: Fields to search (columns or tuples for relationships)
|
||||||
|
case_sensitive: Case-sensitive search (default: False)
|
||||||
|
match_mode: "any" (OR) or "all" (AND) to combine fields
|
||||||
|
"""
|
||||||
|
|
||||||
|
query: str
|
||||||
|
fields: Sequence[SearchFieldType] | None = None
|
||||||
|
case_sensitive: bool = False
|
||||||
|
match_mode: Literal["any", "all"] = "any"
|
||||||
|
|
||||||
|
|
||||||
|
def get_searchable_fields(
|
||||||
|
model: type[DeclarativeBase],
|
||||||
|
*,
|
||||||
|
include_relationships: bool = True,
|
||||||
|
max_depth: int = 1,
|
||||||
|
) -> list[SearchFieldType]:
|
||||||
|
"""Auto-detect String fields on a model and its relationships.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
include_relationships: Include fields from many-to-one/one-to-one relationships
|
||||||
|
max_depth: Max depth for relationship traversal (default: 1)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of columns and tuples (relationship, column)
|
||||||
|
"""
|
||||||
|
fields: list[SearchFieldType] = []
|
||||||
|
mapper = model.__mapper__
|
||||||
|
|
||||||
|
# Direct String columns
|
||||||
|
for col in mapper.columns:
|
||||||
|
if isinstance(col.type, String):
|
||||||
|
fields.append(getattr(model, col.key))
|
||||||
|
|
||||||
|
# Relationships (one-to-one, many-to-one only)
|
||||||
|
if include_relationships and max_depth > 0:
|
||||||
|
for rel_name, rel_prop in mapper.relationships.items():
|
||||||
|
if rel_prop.uselist: # Skip collections (one-to-many, many-to-many)
|
||||||
|
continue
|
||||||
|
|
||||||
|
rel_attr = getattr(model, rel_name)
|
||||||
|
related_model = rel_prop.mapper.class_
|
||||||
|
|
||||||
|
for col in related_model.__mapper__.columns:
|
||||||
|
if isinstance(col.type, String):
|
||||||
|
fields.append((rel_attr, getattr(related_model, col.key)))
|
||||||
|
|
||||||
|
return fields
|
||||||
|
|
||||||
|
|
||||||
|
def build_search_filters(
|
||||||
|
model: type[DeclarativeBase],
|
||||||
|
search: str | SearchConfig,
|
||||||
|
search_fields: Sequence[SearchFieldType] | None = None,
|
||||||
|
default_fields: Sequence[SearchFieldType] | None = None,
|
||||||
|
) -> tuple[list["ColumnElement[bool]"], list[InstrumentedAttribute[Any]]]:
|
||||||
|
"""Build SQLAlchemy filter conditions for search.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
search: Search string or SearchConfig
|
||||||
|
search_fields: Fields specified per-call (takes priority)
|
||||||
|
default_fields: Default fields (from ClassVar)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (filter_conditions, joins_needed)
|
||||||
|
"""
|
||||||
|
# Normalize input
|
||||||
|
if isinstance(search, str):
|
||||||
|
config = SearchConfig(query=search, fields=search_fields)
|
||||||
|
else:
|
||||||
|
config = search
|
||||||
|
if search_fields is not None:
|
||||||
|
config = SearchConfig(
|
||||||
|
query=config.query,
|
||||||
|
fields=search_fields,
|
||||||
|
case_sensitive=config.case_sensitive,
|
||||||
|
match_mode=config.match_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not config.query or not config.query.strip():
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
# Determine which fields to search
|
||||||
|
fields = config.fields or default_fields or get_searchable_fields(model)
|
||||||
|
|
||||||
|
if not fields:
|
||||||
|
raise NoSearchableFieldsError(model)
|
||||||
|
|
||||||
|
query = config.query.strip()
|
||||||
|
filters: list[ColumnElement[bool]] = []
|
||||||
|
joins: list[InstrumentedAttribute[Any]] = []
|
||||||
|
added_joins: set[str] = set()
|
||||||
|
|
||||||
|
for field in fields:
|
||||||
|
if isinstance(field, tuple):
|
||||||
|
# Relationship: (User.role, Role.name) or deeper
|
||||||
|
for rel in field[:-1]:
|
||||||
|
rel_key = str(rel)
|
||||||
|
if rel_key not in added_joins:
|
||||||
|
joins.append(rel)
|
||||||
|
added_joins.add(rel_key)
|
||||||
|
column = field[-1]
|
||||||
|
else:
|
||||||
|
column = field
|
||||||
|
|
||||||
|
# Build the filter (cast to String for non-text columns)
|
||||||
|
column_as_string = column.cast(String)
|
||||||
|
if config.case_sensitive:
|
||||||
|
filters.append(column_as_string.like(f"%{query}%"))
|
||||||
|
else:
|
||||||
|
filters.append(column_as_string.ilike(f"%{query}%"))
|
||||||
|
|
||||||
|
if not filters:
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
# Combine based on match_mode
|
||||||
|
if config.match_mode == "any":
|
||||||
|
return [or_(*filters)], joins
|
||||||
|
else:
|
||||||
|
return filters, joins
|
||||||
@@ -1,8 +1,10 @@
|
|||||||
"""Database utilities: sessions, transactions, and locks."""
|
"""Database utilities: sessions, transactions, and locks."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from collections.abc import AsyncGenerator, Callable
|
from collections.abc import AsyncGenerator, Callable
|
||||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
@@ -14,6 +16,7 @@ __all__ = [
|
|||||||
"create_db_dependency",
|
"create_db_dependency",
|
||||||
"lock_tables",
|
"lock_tables",
|
||||||
"get_transaction",
|
"get_transaction",
|
||||||
|
"wait_for_row_change",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -32,6 +35,7 @@ def create_db_dependency(
|
|||||||
An async generator function usable with FastAPI's Depends()
|
An async generator function usable with FastAPI's Depends()
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||||
from fastapi_toolsets.db import create_db_dependency
|
from fastapi_toolsets.db import create_db_dependency
|
||||||
@@ -43,6 +47,7 @@ def create_db_dependency(
|
|||||||
@app.get("/users")
|
@app.get("/users")
|
||||||
async def list_users(session: AsyncSession = Depends(get_db)):
|
async def list_users(session: AsyncSession = Depends(get_db)):
|
||||||
...
|
...
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||||
@@ -69,6 +74,7 @@ def create_db_context(
|
|||||||
An async context manager function
|
An async context manager function
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||||
from fastapi_toolsets.db import create_db_context
|
from fastapi_toolsets.db import create_db_context
|
||||||
|
|
||||||
@@ -80,6 +86,7 @@ def create_db_context(
|
|||||||
async with get_db_context() as session:
|
async with get_db_context() as session:
|
||||||
user = await UserCrud.get(session, [User.id == 1])
|
user = await UserCrud.get(session, [User.id == 1])
|
||||||
...
|
...
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
get_db = create_db_dependency(session_maker)
|
get_db = create_db_dependency(session_maker)
|
||||||
return asynccontextmanager(get_db)
|
return asynccontextmanager(get_db)
|
||||||
@@ -101,9 +108,11 @@ async def get_transaction(
|
|||||||
The session within the transaction context
|
The session within the transaction context
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
async with get_transaction(session):
|
async with get_transaction(session):
|
||||||
session.add(model)
|
session.add(model)
|
||||||
# Auto-commits on exit, rolls back on exception
|
# Auto-commits on exit, rolls back on exception
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
if session.in_transaction():
|
if session.in_transaction():
|
||||||
async with session.begin_nested():
|
async with session.begin_nested():
|
||||||
@@ -155,6 +164,7 @@ async def lock_tables(
|
|||||||
SQLAlchemyError: If lock cannot be acquired within timeout
|
SQLAlchemyError: If lock cannot be acquired within timeout
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
from fastapi_toolsets.db import lock_tables, LockMode
|
from fastapi_toolsets.db import lock_tables, LockMode
|
||||||
|
|
||||||
async with lock_tables(session, [User, Account]):
|
async with lock_tables(session, [User, Account]):
|
||||||
@@ -166,6 +176,7 @@ async def lock_tables(
|
|||||||
async with lock_tables(session, [Order], mode=LockMode.EXCLUSIVE):
|
async with lock_tables(session, [Order], mode=LockMode.EXCLUSIVE):
|
||||||
# Exclusive lock - no other transactions can access
|
# Exclusive lock - no other transactions can access
|
||||||
await process_order(session, order_id)
|
await process_order(session, order_id)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
table_names = ",".join(table.__tablename__ for table in tables)
|
table_names = ",".join(table.__tablename__ for table in tables)
|
||||||
|
|
||||||
@@ -173,3 +184,85 @@ async def lock_tables(
|
|||||||
await session.execute(text(f"SET LOCAL lock_timeout='{timeout}'"))
|
await session.execute(text(f"SET LOCAL lock_timeout='{timeout}'"))
|
||||||
await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE"))
|
await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE"))
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
_M = TypeVar("_M", bound=DeclarativeBase)
|
||||||
|
|
||||||
|
|
||||||
|
async def wait_for_row_change(
|
||||||
|
session: AsyncSession,
|
||||||
|
model: type[_M],
|
||||||
|
pk_value: Any,
|
||||||
|
*,
|
||||||
|
columns: list[str] | None = None,
|
||||||
|
interval: float = 0.5,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> _M:
|
||||||
|
"""Poll a database row until a change is detected.
|
||||||
|
|
||||||
|
Queries the row every ``interval`` seconds and returns the model instance
|
||||||
|
once a change is detected in any column (or only the specified ``columns``).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: AsyncSession instance
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
pk_value: Primary key value of the row to watch
|
||||||
|
columns: Optional list of column names to watch. If None, all columns
|
||||||
|
are watched.
|
||||||
|
interval: Polling interval in seconds (default: 0.5)
|
||||||
|
timeout: Maximum time to wait in seconds. None means wait forever.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The refreshed model instance with updated values
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
LookupError: If the row does not exist or is deleted during polling
|
||||||
|
TimeoutError: If timeout expires before a change is detected
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import wait_for_row_change
|
||||||
|
|
||||||
|
# Wait for any column to change
|
||||||
|
updated = await wait_for_row_change(session, User, user_id)
|
||||||
|
|
||||||
|
# Watch specific columns with a timeout
|
||||||
|
updated = await wait_for_row_change(
|
||||||
|
session, User, user_id,
|
||||||
|
columns=["status", "email"],
|
||||||
|
interval=1.0,
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
instance = await session.get(model, pk_value)
|
||||||
|
if instance is None:
|
||||||
|
raise LookupError(f"{model.__name__} with pk={pk_value!r} not found")
|
||||||
|
|
||||||
|
if columns is not None:
|
||||||
|
watch_cols = columns
|
||||||
|
else:
|
||||||
|
watch_cols = [attr.key for attr in model.__mapper__.column_attrs]
|
||||||
|
|
||||||
|
initial = {col: getattr(instance, col) for col in watch_cols}
|
||||||
|
|
||||||
|
elapsed = 0.0
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
elapsed += interval
|
||||||
|
|
||||||
|
if timeout is not None and elapsed >= timeout:
|
||||||
|
raise TimeoutError(
|
||||||
|
f"No change detected on {model.__name__} "
|
||||||
|
f"with pk={pk_value!r} within {timeout}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
session.expunge(instance)
|
||||||
|
instance = await session.get(model, pk_value)
|
||||||
|
|
||||||
|
if instance is None:
|
||||||
|
raise LookupError(f"{model.__name__} with pk={pk_value!r} was deleted")
|
||||||
|
|
||||||
|
current = {col: getattr(instance, col) for col in watch_cols}
|
||||||
|
if current != initial:
|
||||||
|
return instance
|
||||||
|
|||||||
145
src/fastapi_toolsets/dependencies.py
Normal file
145
src/fastapi_toolsets/dependencies.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
"""Dependency factories for FastAPI routes."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from collections.abc import AsyncGenerator, Callable
|
||||||
|
from typing import Any, TypeVar, cast
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from .crud import CrudFactory
|
||||||
|
|
||||||
|
__all__ = ["BodyDependency", "PathDependency"]
|
||||||
|
|
||||||
|
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||||
|
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]]
|
||||||
|
|
||||||
|
|
||||||
|
def PathDependency(
|
||||||
|
model: type[ModelType],
|
||||||
|
field: Any,
|
||||||
|
*,
|
||||||
|
session_dep: SessionDependency,
|
||||||
|
param_name: str | None = None,
|
||||||
|
) -> ModelType:
|
||||||
|
"""Create a dependency that fetches a DB object from a path parameter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
field: Model field to filter by (e.g., User.id)
|
||||||
|
session_dep: Session dependency function (e.g., get_db)
|
||||||
|
param_name: Path parameter name (defaults to model_field, e.g., user_id)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Depends() instance that resolves to the model instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If no matching record is found
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
UserDep = PathDependency(User, User.id, session_dep=get_db)
|
||||||
|
|
||||||
|
@router.get("/user/{id}")
|
||||||
|
async def get(
|
||||||
|
user: User = UserDep,
|
||||||
|
): ...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
crud = CrudFactory(model)
|
||||||
|
name = (
|
||||||
|
param_name
|
||||||
|
if param_name is not None
|
||||||
|
else "{}_{}".format(model.__name__.lower(), field.key)
|
||||||
|
)
|
||||||
|
python_type = field.type.python_type
|
||||||
|
|
||||||
|
async def dependency(
|
||||||
|
session: AsyncSession = Depends(session_dep), **kwargs: Any
|
||||||
|
) -> ModelType:
|
||||||
|
value = kwargs[name]
|
||||||
|
return await crud.get(session, filters=[field == value])
|
||||||
|
|
||||||
|
setattr(
|
||||||
|
dependency,
|
||||||
|
"__signature__",
|
||||||
|
inspect.Signature(
|
||||||
|
parameters=[
|
||||||
|
inspect.Parameter(
|
||||||
|
name, inspect.Parameter.KEYWORD_ONLY, annotation=python_type
|
||||||
|
),
|
||||||
|
inspect.Parameter(
|
||||||
|
"session",
|
||||||
|
inspect.Parameter.KEYWORD_ONLY,
|
||||||
|
annotation=AsyncSession,
|
||||||
|
default=Depends(session_dep),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return cast(ModelType, Depends(cast(Callable[..., ModelType], dependency)))
|
||||||
|
|
||||||
|
|
||||||
|
def BodyDependency(
|
||||||
|
model: type[ModelType],
|
||||||
|
field: Any,
|
||||||
|
*,
|
||||||
|
session_dep: SessionDependency,
|
||||||
|
body_field: str,
|
||||||
|
) -> ModelType:
|
||||||
|
"""Create a dependency that fetches a DB object from a body field.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
field: Model field to filter by (e.g., User.id)
|
||||||
|
session_dep: Session dependency function (e.g., get_db)
|
||||||
|
body_field: Name of the field in the request body
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Depends() instance that resolves to the model instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If no matching record is found
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
UserDep = BodyDependency(
|
||||||
|
User, User.ctfd_id, session_dep=get_db, body_field="user_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.post("/assign")
|
||||||
|
async def assign(
|
||||||
|
user: User = UserDep,
|
||||||
|
): ...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
crud = CrudFactory(model)
|
||||||
|
python_type = field.type.python_type
|
||||||
|
|
||||||
|
async def dependency(
|
||||||
|
session: AsyncSession = Depends(session_dep), **kwargs: Any
|
||||||
|
) -> ModelType:
|
||||||
|
value = kwargs[body_field]
|
||||||
|
return await crud.get(session, filters=[field == value])
|
||||||
|
|
||||||
|
setattr(
|
||||||
|
dependency,
|
||||||
|
"__signature__",
|
||||||
|
inspect.Signature(
|
||||||
|
parameters=[
|
||||||
|
inspect.Parameter(
|
||||||
|
body_field, inspect.Parameter.KEYWORD_ONLY, annotation=python_type
|
||||||
|
),
|
||||||
|
inspect.Parameter(
|
||||||
|
"session",
|
||||||
|
inspect.Parameter.KEYWORD_ONLY,
|
||||||
|
annotation=AsyncSession,
|
||||||
|
default=Depends(session_dep),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return cast(ModelType, Depends(cast(Callable[..., ModelType], dependency)))
|
||||||
@@ -1,7 +1,11 @@
|
|||||||
|
"""Standardized API exceptions and error response handlers."""
|
||||||
|
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
|
ApiError,
|
||||||
ApiException,
|
ApiException,
|
||||||
ConflictError,
|
ConflictError,
|
||||||
ForbiddenError,
|
ForbiddenError,
|
||||||
|
NoSearchableFieldsError,
|
||||||
NotFoundError,
|
NotFoundError,
|
||||||
UnauthorizedError,
|
UnauthorizedError,
|
||||||
generate_error_responses,
|
generate_error_responses,
|
||||||
@@ -9,11 +13,13 @@ from .exceptions import (
|
|||||||
from .handler import init_exceptions_handlers
|
from .handler import init_exceptions_handlers
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"init_exceptions_handlers",
|
"ApiError",
|
||||||
"generate_error_responses",
|
|
||||||
"ApiException",
|
"ApiException",
|
||||||
"ConflictError",
|
"ConflictError",
|
||||||
"ForbiddenError",
|
"ForbiddenError",
|
||||||
|
"generate_error_responses",
|
||||||
|
"init_exceptions_handlers",
|
||||||
|
"NoSearchableFieldsError",
|
||||||
"NotFoundError",
|
"NotFoundError",
|
||||||
"UnauthorizedError",
|
"UnauthorizedError",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ class ApiException(Exception):
|
|||||||
The exception handler will use api_error to generate the response.
|
The exception handler will use api_error to generate the response.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
class CustomError(ApiException):
|
class CustomError(ApiException):
|
||||||
api_error = ApiError(
|
api_error = ApiError(
|
||||||
code=400,
|
code=400,
|
||||||
@@ -19,6 +20,7 @@ class ApiException(Exception):
|
|||||||
desc="The request was invalid.",
|
desc="The request was invalid.",
|
||||||
err_code="CUSTOM-400",
|
err_code="CUSTOM-400",
|
||||||
)
|
)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
api_error: ClassVar[ApiError]
|
api_error: ClassVar[ApiError]
|
||||||
@@ -76,47 +78,28 @@ class ConflictError(ApiException):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InsufficientRolesError(ForbiddenError):
|
class NoSearchableFieldsError(ApiException):
|
||||||
"""User does not have the required roles."""
|
"""Raised when search is requested but no searchable fields are available."""
|
||||||
|
|
||||||
api_error = ApiError(
|
api_error = ApiError(
|
||||||
code=403,
|
code=400,
|
||||||
msg="Insufficient Roles",
|
msg="No Searchable Fields",
|
||||||
desc="You do not have the required roles to access this resource.",
|
desc="No searchable fields configured for this resource.",
|
||||||
err_code="RBAC-403",
|
err_code="SEARCH-400",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, required_roles: list[str], user_roles: set[str] | None = None):
|
def __init__(self, model: type) -> None:
|
||||||
self.required_roles = required_roles
|
"""Initialize the exception.
|
||||||
self.user_roles = user_roles
|
|
||||||
|
|
||||||
desc = f"Required roles: {', '.join(required_roles)}"
|
Args:
|
||||||
if user_roles is not None:
|
model: The SQLAlchemy model class that has no searchable fields
|
||||||
desc += f". User has: {', '.join(user_roles) if user_roles else 'no roles'}"
|
"""
|
||||||
|
self.model = model
|
||||||
super().__init__(desc)
|
detail = (
|
||||||
|
f"No searchable fields found for model '{model.__name__}'. "
|
||||||
|
"Provide 'search_fields' parameter or set 'searchable_fields' on the CRUD class."
|
||||||
class UserNotFoundError(NotFoundError):
|
|
||||||
"""User was not found."""
|
|
||||||
|
|
||||||
api_error = ApiError(
|
|
||||||
code=404,
|
|
||||||
msg="User Not Found",
|
|
||||||
desc="The requested user was not found.",
|
|
||||||
err_code="USER-404",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RoleNotFoundError(NotFoundError):
|
|
||||||
"""Role was not found."""
|
|
||||||
|
|
||||||
api_error = ApiError(
|
|
||||||
code=404,
|
|
||||||
msg="Role Not Found",
|
|
||||||
desc="The requested role was not found.",
|
|
||||||
err_code="ROLE-404",
|
|
||||||
)
|
)
|
||||||
|
super().__init__(detail)
|
||||||
|
|
||||||
|
|
||||||
def generate_error_responses(
|
def generate_error_responses(
|
||||||
@@ -133,6 +116,7 @@ def generate_error_responses(
|
|||||||
Dict suitable for FastAPI's responses parameter
|
Dict suitable for FastAPI's responses parameter
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
from fastapi_toolsets.exceptions import generate_error_responses, UnauthorizedError, ForbiddenError
|
from fastapi_toolsets.exceptions import generate_error_responses, UnauthorizedError, ForbiddenError
|
||||||
|
|
||||||
@app.get(
|
@app.get(
|
||||||
@@ -141,6 +125,7 @@ def generate_error_responses(
|
|||||||
)
|
)
|
||||||
async def admin_endpoint():
|
async def admin_endpoint():
|
||||||
...
|
...
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
responses: dict[int | str, dict[str, Any]] = {}
|
responses: dict[int | str, dict[str, Any]] = {}
|
||||||
|
|
||||||
@@ -153,7 +138,7 @@ def generate_error_responses(
|
|||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"example": {
|
"example": {
|
||||||
"data": None,
|
"data": api_error.data,
|
||||||
"status": ResponseStatus.FAIL.value,
|
"status": ResponseStatus.FAIL.value,
|
||||||
"message": api_error.msg,
|
"message": api_error.msg,
|
||||||
"description": api_error.desc,
|
"description": api_error.desc,
|
||||||
|
|||||||
@@ -7,11 +7,32 @@ from fastapi.exceptions import RequestValidationError, ResponseValidationError
|
|||||||
from fastapi.openapi.utils import get_openapi
|
from fastapi.openapi.utils import get_openapi
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from ..schemas import ResponseStatus
|
from ..schemas import ErrorResponse, ResponseStatus
|
||||||
from .exceptions import ApiException
|
from .exceptions import ApiException
|
||||||
|
|
||||||
|
|
||||||
def init_exceptions_handlers(app: FastAPI) -> FastAPI:
|
def init_exceptions_handlers(app: FastAPI) -> FastAPI:
|
||||||
|
"""Register exception handlers and custom OpenAPI schema on a FastAPI app.
|
||||||
|
|
||||||
|
Installs handlers for :class:`ApiException`, validation errors, and
|
||||||
|
unhandled exceptions, and replaces the default 422 schema with a
|
||||||
|
consistent error format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI application instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The same FastAPI instance (for chaining)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
```
|
||||||
|
"""
|
||||||
_register_exception_handlers(app)
|
_register_exception_handlers(app)
|
||||||
app.openapi = lambda: _custom_openapi(app) # type: ignore[method-assign]
|
app.openapi = lambda: _custom_openapi(app) # type: ignore[method-assign]
|
||||||
return app
|
return app
|
||||||
@@ -35,16 +56,16 @@ def _register_exception_handlers(app: FastAPI) -> None:
|
|||||||
async def api_exception_handler(request: Request, exc: ApiException) -> Response:
|
async def api_exception_handler(request: Request, exc: ApiException) -> Response:
|
||||||
"""Handle custom API exceptions with structured response."""
|
"""Handle custom API exceptions with structured response."""
|
||||||
api_error = exc.api_error
|
api_error = exc.api_error
|
||||||
|
error_response = ErrorResponse(
|
||||||
|
data=api_error.data,
|
||||||
|
message=api_error.msg,
|
||||||
|
description=api_error.desc,
|
||||||
|
error_code=api_error.err_code,
|
||||||
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=api_error.code,
|
status_code=api_error.code,
|
||||||
content={
|
content=error_response.model_dump(),
|
||||||
"data": None,
|
|
||||||
"status": ResponseStatus.FAIL.value,
|
|
||||||
"message": api_error.msg,
|
|
||||||
"description": api_error.desc,
|
|
||||||
"error_code": api_error.err_code,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.exception_handler(RequestValidationError)
|
@app.exception_handler(RequestValidationError)
|
||||||
@@ -64,15 +85,15 @@ def _register_exception_handlers(app: FastAPI) -> None:
|
|||||||
@app.exception_handler(Exception)
|
@app.exception_handler(Exception)
|
||||||
async def generic_exception_handler(request: Request, exc: Exception) -> Response:
|
async def generic_exception_handler(request: Request, exc: Exception) -> Response:
|
||||||
"""Handle all unhandled exceptions with a generic 500 response."""
|
"""Handle all unhandled exceptions with a generic 500 response."""
|
||||||
|
error_response = ErrorResponse(
|
||||||
|
message="Internal Server Error",
|
||||||
|
description="An unexpected error occurred. Please try again later.",
|
||||||
|
error_code="SERVER-500",
|
||||||
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
content={
|
content=error_response.model_dump(),
|
||||||
"data": None,
|
|
||||||
"status": ResponseStatus.FAIL.value,
|
|
||||||
"message": "Internal Server Error",
|
|
||||||
"description": "An unexpected error occurred. Please try again later.",
|
|
||||||
"error_code": "SERVER-500",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -97,15 +118,16 @@ def _format_validation_error(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
error_response = ErrorResponse(
|
||||||
|
data={"errors": formatted_errors},
|
||||||
|
message="Validation Error",
|
||||||
|
description=f"{len(formatted_errors)} validation error(s) detected",
|
||||||
|
error_code="VAL-422",
|
||||||
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
content={
|
content=error_response.model_dump(),
|
||||||
"data": {"errors": formatted_errors},
|
|
||||||
"status": ResponseStatus.FAIL.value,
|
|
||||||
"message": "Validation Error",
|
|
||||||
"description": f"{len(formatted_errors)} validation error(s) detected",
|
|
||||||
"error_code": "VAL-422",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,8 @@
|
|||||||
from .fixtures import (
|
"""Fixture system for seeding databases with dependency resolution."""
|
||||||
Context,
|
|
||||||
FixtureRegistry,
|
from .enum import LoadStrategy
|
||||||
LoadStrategy,
|
from .registry import Context, FixtureRegistry
|
||||||
load_fixtures,
|
from .utils import get_obj_by_attr, load_fixtures, load_fixtures_by_context
|
||||||
load_fixtures_by_context,
|
|
||||||
)
|
|
||||||
from .utils import get_obj_by_attr
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Context",
|
"Context",
|
||||||
@@ -16,12 +13,3 @@ __all__ = [
|
|||||||
"load_fixtures_by_context",
|
"load_fixtures_by_context",
|
||||||
"register_fixtures",
|
"register_fixtures",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# We lazy-load register_fixtures to avoid needing pytest when using fixtures CLI
|
|
||||||
def __getattr__(name: str):
|
|
||||||
if name == "register_fixtures":
|
|
||||||
from .pytest_plugin import register_fixtures
|
|
||||||
|
|
||||||
return register_fixtures
|
|
||||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
|
||||||
|
|||||||
32
src/fastapi_toolsets/fixtures/enum.py
Normal file
32
src/fastapi_toolsets/fixtures/enum.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""Enums for fixture loading strategies and contexts."""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class LoadStrategy(str, Enum):
|
||||||
|
"""Strategy for loading fixtures into the database."""
|
||||||
|
|
||||||
|
INSERT = "insert"
|
||||||
|
"""Insert new records. Fails if record already exists."""
|
||||||
|
|
||||||
|
MERGE = "merge"
|
||||||
|
"""Insert or update based on primary key (SQLAlchemy merge)."""
|
||||||
|
|
||||||
|
SKIP_EXISTING = "skip_existing"
|
||||||
|
"""Insert only if record doesn't exist (based on primary key)."""
|
||||||
|
|
||||||
|
|
||||||
|
class Context(str, Enum):
|
||||||
|
"""Predefined fixture contexts."""
|
||||||
|
|
||||||
|
BASE = "base"
|
||||||
|
"""Base fixtures loaded in all environments."""
|
||||||
|
|
||||||
|
PRODUCTION = "production"
|
||||||
|
"""Production-only fixtures."""
|
||||||
|
|
||||||
|
DEVELOPMENT = "development"
|
||||||
|
"""Development fixtures."""
|
||||||
|
|
||||||
|
TESTING = "testing"
|
||||||
|
"""Test fixtures."""
|
||||||
@@ -1,46 +1,15 @@
|
|||||||
"""Fixture system with dependency management and context support."""
|
"""Fixture system with dependency management and context support."""
|
||||||
|
|
||||||
import logging
|
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
from ..db import get_transaction
|
from ..logger import get_logger
|
||||||
|
from .enum import Context
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
class LoadStrategy(str, Enum):
|
|
||||||
"""Strategy for loading fixtures into the database."""
|
|
||||||
|
|
||||||
INSERT = "insert"
|
|
||||||
"""Insert new records. Fails if record already exists."""
|
|
||||||
|
|
||||||
MERGE = "merge"
|
|
||||||
"""Insert or update based on primary key (SQLAlchemy merge)."""
|
|
||||||
|
|
||||||
SKIP_EXISTING = "skip_existing"
|
|
||||||
"""Insert only if record doesn't exist (based on primary key)."""
|
|
||||||
|
|
||||||
|
|
||||||
class Context(str, Enum):
|
|
||||||
"""Predefined fixture contexts."""
|
|
||||||
|
|
||||||
BASE = "base"
|
|
||||||
"""Base fixtures loaded in all environments."""
|
|
||||||
|
|
||||||
PRODUCTION = "production"
|
|
||||||
"""Production-only fixtures."""
|
|
||||||
|
|
||||||
DEVELOPMENT = "development"
|
|
||||||
"""Development fixtures."""
|
|
||||||
|
|
||||||
TESTING = "testing"
|
|
||||||
"""Test fixtures."""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -57,6 +26,7 @@ class FixtureRegistry:
|
|||||||
"""Registry for managing fixtures with dependencies.
|
"""Registry for managing fixtures with dependencies.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
from fastapi_toolsets.fixtures import FixtureRegistry, Context
|
from fastapi_toolsets.fixtures import FixtureRegistry, Context
|
||||||
|
|
||||||
fixtures = FixtureRegistry()
|
fixtures = FixtureRegistry()
|
||||||
@@ -79,10 +49,19 @@ class FixtureRegistry:
|
|||||||
return [
|
return [
|
||||||
Post(id=1, title="Test", user_id=1),
|
Post(id=1, title="Test", user_id=1),
|
||||||
]
|
]
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
contexts: list[str | Context] | None = None,
|
||||||
|
) -> None:
|
||||||
self._fixtures: dict[str, Fixture] = {}
|
self._fixtures: dict[str, Fixture] = {}
|
||||||
|
self._default_contexts: list[str] | None = (
|
||||||
|
[c.value if isinstance(c, Context) else c for c in contexts]
|
||||||
|
if contexts
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
def register(
|
def register(
|
||||||
self,
|
self,
|
||||||
@@ -103,6 +82,7 @@ class FixtureRegistry:
|
|||||||
contexts: List of contexts this fixture belongs to
|
contexts: List of contexts this fixture belongs to
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
@fixtures.register
|
@fixtures.register
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=1, name="admin")]
|
||||||
@@ -110,16 +90,21 @@ class FixtureRegistry:
|
|||||||
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
|
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
|
||||||
def test_users():
|
def test_users():
|
||||||
return [User(id=1, username="test", role_id=1)]
|
return [User(id=1, username="test", role_id=1)]
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(
|
def decorator(
|
||||||
fn: Callable[[], Sequence[DeclarativeBase]],
|
fn: Callable[[], Sequence[DeclarativeBase]],
|
||||||
) -> Callable[[], Sequence[DeclarativeBase]]:
|
) -> Callable[[], Sequence[DeclarativeBase]]:
|
||||||
fixture_name = name or cast(Any, fn).__name__
|
fixture_name = name or cast(Any, fn).__name__
|
||||||
|
if contexts is not None:
|
||||||
fixture_contexts = [
|
fixture_contexts = [
|
||||||
c.value if isinstance(c, Context) else c
|
c.value if isinstance(c, Context) else c for c in contexts
|
||||||
for c in (contexts or [Context.BASE])
|
|
||||||
]
|
]
|
||||||
|
elif self._default_contexts is not None:
|
||||||
|
fixture_contexts = self._default_contexts
|
||||||
|
else:
|
||||||
|
fixture_contexts = [Context.BASE.value]
|
||||||
|
|
||||||
self._fixtures[fixture_name] = Fixture(
|
self._fixtures[fixture_name] = Fixture(
|
||||||
name=fixture_name,
|
name=fixture_name,
|
||||||
@@ -133,6 +118,34 @@ class FixtureRegistry:
|
|||||||
return decorator(func)
|
return decorator(func)
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
def include_registry(self, registry: "FixtureRegistry") -> None:
|
||||||
|
"""Include another `FixtureRegistry` in the same current `FixtureRegistry`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registry: The `FixtureRegistry` to include
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If a fixture name already exists in the current registry
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
dev_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@dev_registry.register
|
||||||
|
def dev_data():
|
||||||
|
return [...]
|
||||||
|
|
||||||
|
registry.include_registry(registry=dev_registry)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
for name, fixture in registry._fixtures.items():
|
||||||
|
if name in self._fixtures:
|
||||||
|
raise ValueError(
|
||||||
|
f"Fixture '{name}' already exists in the current registry"
|
||||||
|
)
|
||||||
|
self._fixtures[name] = fixture
|
||||||
|
|
||||||
def get(self, name: str) -> Fixture:
|
def get(self, name: str) -> Fixture:
|
||||||
"""Get a fixture by name."""
|
"""Get a fixture by name."""
|
||||||
if name not in self._fixtures:
|
if name not in self._fixtures:
|
||||||
@@ -204,118 +217,3 @@ class FixtureRegistry:
|
|||||||
all_deps.update(deps)
|
all_deps.update(deps)
|
||||||
|
|
||||||
return self.resolve_dependencies(*all_deps)
|
return self.resolve_dependencies(*all_deps)
|
||||||
|
|
||||||
|
|
||||||
async def load_fixtures(
|
|
||||||
session: AsyncSession,
|
|
||||||
registry: FixtureRegistry,
|
|
||||||
*names: str,
|
|
||||||
strategy: LoadStrategy = LoadStrategy.MERGE,
|
|
||||||
) -> dict[str, list[DeclarativeBase]]:
|
|
||||||
"""Load specific fixtures by name with dependencies.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: Database session
|
|
||||||
registry: Fixture registry
|
|
||||||
*names: Fixture names to load (dependencies auto-resolved)
|
|
||||||
strategy: How to handle existing records
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict mapping fixture names to loaded instances
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# Loads 'roles' first (dependency), then 'users'
|
|
||||||
result = await load_fixtures(session, fixtures, "users")
|
|
||||||
print(result["users"]) # [User(...), ...]
|
|
||||||
"""
|
|
||||||
ordered = registry.resolve_dependencies(*names)
|
|
||||||
return await _load_ordered(session, registry, ordered, strategy)
|
|
||||||
|
|
||||||
|
|
||||||
async def load_fixtures_by_context(
|
|
||||||
session: AsyncSession,
|
|
||||||
registry: FixtureRegistry,
|
|
||||||
*contexts: str | Context,
|
|
||||||
strategy: LoadStrategy = LoadStrategy.MERGE,
|
|
||||||
) -> dict[str, list[DeclarativeBase]]:
|
|
||||||
"""Load all fixtures for specific contexts.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: Database session
|
|
||||||
registry: Fixture registry
|
|
||||||
*contexts: Contexts to load (e.g., Context.BASE, Context.TESTING)
|
|
||||||
strategy: How to handle existing records
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict mapping fixture names to loaded instances
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# Load base + testing fixtures
|
|
||||||
await load_fixtures_by_context(
|
|
||||||
session, fixtures,
|
|
||||||
Context.BASE, Context.TESTING
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
ordered = registry.resolve_context_dependencies(*contexts)
|
|
||||||
return await _load_ordered(session, registry, ordered, strategy)
|
|
||||||
|
|
||||||
|
|
||||||
async def _load_ordered(
|
|
||||||
session: AsyncSession,
|
|
||||||
registry: FixtureRegistry,
|
|
||||||
ordered_names: list[str],
|
|
||||||
strategy: LoadStrategy,
|
|
||||||
) -> dict[str, list[DeclarativeBase]]:
|
|
||||||
"""Load fixtures in order."""
|
|
||||||
results: dict[str, list[DeclarativeBase]] = {}
|
|
||||||
|
|
||||||
for name in ordered_names:
|
|
||||||
fixture = registry.get(name)
|
|
||||||
instances = list(fixture.func())
|
|
||||||
|
|
||||||
if not instances:
|
|
||||||
results[name] = []
|
|
||||||
continue
|
|
||||||
|
|
||||||
model_name = type(instances[0]).__name__
|
|
||||||
loaded: list[DeclarativeBase] = []
|
|
||||||
|
|
||||||
async with get_transaction(session):
|
|
||||||
for instance in instances:
|
|
||||||
if strategy == LoadStrategy.INSERT:
|
|
||||||
session.add(instance)
|
|
||||||
loaded.append(instance)
|
|
||||||
|
|
||||||
elif strategy == LoadStrategy.MERGE:
|
|
||||||
merged = await session.merge(instance)
|
|
||||||
loaded.append(merged)
|
|
||||||
|
|
||||||
elif strategy == LoadStrategy.SKIP_EXISTING:
|
|
||||||
pk = _get_primary_key(instance)
|
|
||||||
if pk is not None:
|
|
||||||
existing = await session.get(type(instance), pk)
|
|
||||||
if existing is None:
|
|
||||||
session.add(instance)
|
|
||||||
loaded.append(instance)
|
|
||||||
else:
|
|
||||||
session.add(instance)
|
|
||||||
loaded.append(instance)
|
|
||||||
|
|
||||||
results[name] = loaded
|
|
||||||
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
|
||||||
"""Get the primary key value of a model instance."""
|
|
||||||
mapper = instance.__class__.__mapper__
|
|
||||||
pk_cols = mapper.primary_key
|
|
||||||
|
|
||||||
if len(pk_cols) == 1:
|
|
||||||
return getattr(instance, pk_cols[0].name, None)
|
|
||||||
|
|
||||||
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
|
|
||||||
if all(v is not None for v in pk_values):
|
|
||||||
return pk_values
|
|
||||||
return None
|
|
||||||
@@ -1,8 +1,18 @@
|
|||||||
|
"""Fixture loading utilities for database seeding."""
|
||||||
|
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from typing import Any, TypeVar
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from ..db import get_transaction
|
||||||
|
from ..logger import get_logger
|
||||||
|
from .enum import LoadStrategy
|
||||||
|
from .registry import Context, FixtureRegistry
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
T = TypeVar("T", bound=DeclarativeBase)
|
T = TypeVar("T", bound=DeclarativeBase)
|
||||||
|
|
||||||
|
|
||||||
@@ -21,6 +31,130 @@ def get_obj_by_attr(
|
|||||||
The first model instance where the attribute matches the given value.
|
The first model instance where the attribute matches the given value.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
StopIteration: If no matching object is found.
|
StopIteration: If no matching object is found in the fixture group.
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
return next(obj for obj in fixtures() if getattr(obj, attr_name) == value)
|
return next(obj for obj in fixtures() if getattr(obj, attr_name) == value)
|
||||||
|
except StopIteration:
|
||||||
|
raise StopIteration(
|
||||||
|
f"No object with {attr_name}={value} found in fixture '{getattr(fixtures, '__name__', repr(fixtures))}'"
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
async def load_fixtures(
|
||||||
|
session: AsyncSession,
|
||||||
|
registry: FixtureRegistry,
|
||||||
|
*names: str,
|
||||||
|
strategy: LoadStrategy = LoadStrategy.MERGE,
|
||||||
|
) -> dict[str, list[DeclarativeBase]]:
|
||||||
|
"""Load specific fixtures by name with dependencies.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
registry: Fixture registry
|
||||||
|
*names: Fixture names to load (dependencies auto-resolved)
|
||||||
|
strategy: How to handle existing records
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping fixture names to loaded instances
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
# Loads 'roles' first (dependency), then 'users'
|
||||||
|
result = await load_fixtures(session, fixtures, "users")
|
||||||
|
print(result["users"]) # [User(...), ...]
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
ordered = registry.resolve_dependencies(*names)
|
||||||
|
return await _load_ordered(session, registry, ordered, strategy)
|
||||||
|
|
||||||
|
|
||||||
|
async def load_fixtures_by_context(
|
||||||
|
session: AsyncSession,
|
||||||
|
registry: FixtureRegistry,
|
||||||
|
*contexts: str | Context,
|
||||||
|
strategy: LoadStrategy = LoadStrategy.MERGE,
|
||||||
|
) -> dict[str, list[DeclarativeBase]]:
|
||||||
|
"""Load all fixtures for specific contexts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
registry: Fixture registry
|
||||||
|
*contexts: Contexts to load (e.g., Context.BASE, Context.TESTING)
|
||||||
|
strategy: How to handle existing records
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping fixture names to loaded instances
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
# Load base + testing fixtures
|
||||||
|
await load_fixtures_by_context(
|
||||||
|
session, fixtures,
|
||||||
|
Context.BASE, Context.TESTING
|
||||||
|
)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
ordered = registry.resolve_context_dependencies(*contexts)
|
||||||
|
return await _load_ordered(session, registry, ordered, strategy)
|
||||||
|
|
||||||
|
|
||||||
|
async def _load_ordered(
|
||||||
|
session: AsyncSession,
|
||||||
|
registry: FixtureRegistry,
|
||||||
|
ordered_names: list[str],
|
||||||
|
strategy: LoadStrategy,
|
||||||
|
) -> dict[str, list[DeclarativeBase]]:
|
||||||
|
"""Load fixtures in order."""
|
||||||
|
results: dict[str, list[DeclarativeBase]] = {}
|
||||||
|
|
||||||
|
for name in ordered_names:
|
||||||
|
fixture = registry.get(name)
|
||||||
|
instances = list(fixture.func())
|
||||||
|
|
||||||
|
if not instances:
|
||||||
|
results[name] = []
|
||||||
|
continue
|
||||||
|
|
||||||
|
model_name = type(instances[0]).__name__
|
||||||
|
loaded: list[DeclarativeBase] = []
|
||||||
|
|
||||||
|
async with get_transaction(session):
|
||||||
|
for instance in instances:
|
||||||
|
if strategy == LoadStrategy.INSERT:
|
||||||
|
session.add(instance)
|
||||||
|
loaded.append(instance)
|
||||||
|
|
||||||
|
elif strategy == LoadStrategy.MERGE:
|
||||||
|
merged = await session.merge(instance)
|
||||||
|
loaded.append(merged)
|
||||||
|
|
||||||
|
elif strategy == LoadStrategy.SKIP_EXISTING:
|
||||||
|
pk = _get_primary_key(instance)
|
||||||
|
if pk is not None:
|
||||||
|
existing = await session.get(type(instance), pk)
|
||||||
|
if existing is None:
|
||||||
|
session.add(instance)
|
||||||
|
loaded.append(instance)
|
||||||
|
else:
|
||||||
|
session.add(instance)
|
||||||
|
loaded.append(instance)
|
||||||
|
|
||||||
|
results[name] = loaded
|
||||||
|
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
||||||
|
"""Get the primary key value of a model instance."""
|
||||||
|
mapper = instance.__class__.__mapper__
|
||||||
|
pk_cols = mapper.primary_key
|
||||||
|
|
||||||
|
if len(pk_cols) == 1:
|
||||||
|
return getattr(instance, pk_cols[0].name, None)
|
||||||
|
|
||||||
|
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
|
||||||
|
if all(v is not None for v in pk_values):
|
||||||
|
return pk_values
|
||||||
|
return None
|
||||||
|
|||||||
98
src/fastapi_toolsets/logger.py
Normal file
98
src/fastapi_toolsets/logger.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""Logging configuration for FastAPI applications and CLI tools."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
__all__ = ["LogLevel", "configure_logging", "get_logger"]
|
||||||
|
|
||||||
|
DEFAULT_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
UVICORN_LOGGERS = ("uvicorn", "uvicorn.access", "uvicorn.error")
|
||||||
|
|
||||||
|
LogLevel = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
|
||||||
|
|
||||||
|
|
||||||
|
def configure_logging(
|
||||||
|
level: LogLevel | int = "INFO",
|
||||||
|
fmt: str = DEFAULT_FORMAT,
|
||||||
|
logger_name: str | None = None,
|
||||||
|
) -> logging.Logger:
|
||||||
|
"""Configure logging with a stdout handler and consistent format.
|
||||||
|
|
||||||
|
Sets up a :class:`~logging.StreamHandler` writing to stdout with the
|
||||||
|
given format and level. Also configures the uvicorn loggers so that
|
||||||
|
FastAPI access logs use the same format.
|
||||||
|
|
||||||
|
Calling this function multiple times is safe -- existing handlers are
|
||||||
|
replaced rather than duplicated.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: Log level (e.g. ``"DEBUG"``, ``"INFO"``, or ``logging.DEBUG``).
|
||||||
|
fmt: Log format string. Defaults to
|
||||||
|
``"%(asctime)s - %(name)s - %(levelname)s - %(message)s"``.
|
||||||
|
logger_name: Logger name to configure. ``None`` (the default)
|
||||||
|
configures the root logger so all loggers inherit the settings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The configured Logger instance.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import configure_logging
|
||||||
|
|
||||||
|
logger = configure_logging("DEBUG")
|
||||||
|
logger.info("Application started")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
formatter = logging.Formatter(fmt)
|
||||||
|
|
||||||
|
handler = logging.StreamHandler(sys.stdout)
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
logger = logging.getLogger(logger_name)
|
||||||
|
logger.handlers.clear()
|
||||||
|
logger.addHandler(handler)
|
||||||
|
logger.setLevel(level)
|
||||||
|
|
||||||
|
for name in UVICORN_LOGGERS:
|
||||||
|
uv_logger = logging.getLogger(name)
|
||||||
|
uv_logger.handlers.clear()
|
||||||
|
uv_logger.addHandler(handler)
|
||||||
|
uv_logger.setLevel(level)
|
||||||
|
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
_SENTINEL = object()
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(name: str | None = _SENTINEL) -> logging.Logger: # type: ignore[assignment]
|
||||||
|
"""Return a logger with the given *name*.
|
||||||
|
|
||||||
|
A thin convenience wrapper around :func:`logging.getLogger` that keeps
|
||||||
|
logging imports consistent across the codebase.
|
||||||
|
|
||||||
|
When called without arguments, the caller's ``__name__`` is used
|
||||||
|
automatically, so ``get_logger()`` in a module is equivalent to
|
||||||
|
``logging.getLogger(__name__)``. Pass ``None`` explicitly to get the
|
||||||
|
root logger.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Logger name. Defaults to the caller's ``__name__``.
|
||||||
|
Pass ``None`` to get the root logger.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Logger instance.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger() # uses caller's __name__
|
||||||
|
logger = get_logger("myapp") # explicit name
|
||||||
|
logger = get_logger(None) # root logger
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if name is _SENTINEL:
|
||||||
|
name = sys._getframe(1).f_globals.get("__name__")
|
||||||
|
return logging.getLogger(name)
|
||||||
21
src/fastapi_toolsets/metrics/__init__.py
Normal file
21
src/fastapi_toolsets/metrics/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""Prometheus metrics integration for FastAPI applications."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .registry import Metric, MetricsRegistry
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .handler import init_metrics
|
||||||
|
except ImportError:
|
||||||
|
|
||||||
|
def init_metrics(*_args: Any, **_kwargs: Any) -> None:
|
||||||
|
from .._imports import require_extra
|
||||||
|
|
||||||
|
require_extra(package="prometheus_client", extra="metrics")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Metric",
|
||||||
|
"MetricsRegistry",
|
||||||
|
"init_metrics",
|
||||||
|
]
|
||||||
75
src/fastapi_toolsets/metrics/handler.py
Normal file
75
src/fastapi_toolsets/metrics/handler.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
"""Prometheus metrics endpoint for FastAPI applications."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.responses import Response
|
||||||
|
from prometheus_client import (
|
||||||
|
CONTENT_TYPE_LATEST,
|
||||||
|
CollectorRegistry,
|
||||||
|
generate_latest,
|
||||||
|
multiprocess,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..logger import get_logger
|
||||||
|
from .registry import MetricsRegistry
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _is_multiprocess() -> bool:
|
||||||
|
"""Check if prometheus multi-process mode is enabled."""
|
||||||
|
return "PROMETHEUS_MULTIPROC_DIR" in os.environ
|
||||||
|
|
||||||
|
|
||||||
|
def init_metrics(
|
||||||
|
app: FastAPI,
|
||||||
|
registry: MetricsRegistry,
|
||||||
|
*,
|
||||||
|
path: str = "/metrics",
|
||||||
|
) -> FastAPI:
|
||||||
|
"""Register a Prometheus ``/metrics`` endpoint on a FastAPI app.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI application instance.
|
||||||
|
registry: A :class:`MetricsRegistry` containing providers and collectors.
|
||||||
|
path: URL path for the metrics endpoint (default ``/metrics``).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The same FastAPI instance (for chaining).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.metrics import MetricsRegistry, init_metrics
|
||||||
|
|
||||||
|
metrics = MetricsRegistry()
|
||||||
|
app = FastAPI()
|
||||||
|
init_metrics(app, registry=metrics)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
for provider in registry.get_providers():
|
||||||
|
logger.debug("Initialising metric provider '%s'", provider.name)
|
||||||
|
provider.func()
|
||||||
|
|
||||||
|
collectors = registry.get_collectors()
|
||||||
|
|
||||||
|
@app.get(path, include_in_schema=False)
|
||||||
|
async def metrics_endpoint() -> Response:
|
||||||
|
for collector in collectors:
|
||||||
|
if asyncio.iscoroutinefunction(collector.func):
|
||||||
|
await collector.func()
|
||||||
|
else:
|
||||||
|
collector.func()
|
||||||
|
|
||||||
|
if _is_multiprocess():
|
||||||
|
prom_registry = CollectorRegistry()
|
||||||
|
multiprocess.MultiProcessCollector(prom_registry)
|
||||||
|
output = generate_latest(prom_registry)
|
||||||
|
else:
|
||||||
|
output = generate_latest()
|
||||||
|
|
||||||
|
return Response(content=output, media_type=CONTENT_TYPE_LATEST)
|
||||||
|
|
||||||
|
return app
|
||||||
128
src/fastapi_toolsets/metrics/registry.py
Normal file
128
src/fastapi_toolsets/metrics/registry.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
"""Metrics registry with decorator-based registration."""
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from ..logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Metric:
|
||||||
|
"""A metric definition with metadata."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
func: Callable[..., Any]
|
||||||
|
collect: bool = field(default=False)
|
||||||
|
|
||||||
|
|
||||||
|
class MetricsRegistry:
|
||||||
|
"""Registry for managing Prometheus metric providers and collectors.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from prometheus_client import Counter, Gauge
|
||||||
|
from fastapi_toolsets.metrics import MetricsRegistry
|
||||||
|
|
||||||
|
metrics = MetricsRegistry()
|
||||||
|
|
||||||
|
@metrics.register
|
||||||
|
def http_requests():
|
||||||
|
return Counter("http_requests_total", "Total HTTP requests", ["method", "status"])
|
||||||
|
|
||||||
|
@metrics.register(name="db_pool")
|
||||||
|
def database_pool_size():
|
||||||
|
return Gauge("db_pool_size", "Database connection pool size")
|
||||||
|
|
||||||
|
@metrics.register(collect=True)
|
||||||
|
def collect_queue_depth(gauge=Gauge("queue_depth", "Current queue depth")):
|
||||||
|
gauge.set(get_current_queue_depth())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._metrics: dict[str, Metric] = {}
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self,
|
||||||
|
func: Callable[..., Any] | None = None,
|
||||||
|
*,
|
||||||
|
name: str | None = None,
|
||||||
|
collect: bool = False,
|
||||||
|
) -> Callable[..., Any]:
|
||||||
|
"""Register a metric provider or collector function.
|
||||||
|
|
||||||
|
Can be used as a decorator with or without arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The metric function to register.
|
||||||
|
name: Metric name (defaults to function name).
|
||||||
|
collect: If ``True``, the function is called on every scrape.
|
||||||
|
If ``False`` (default), called once at init time.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
@metrics.register
|
||||||
|
def my_counter():
|
||||||
|
return Counter("my_counter", "A counter")
|
||||||
|
|
||||||
|
@metrics.register(collect=True, name="queue")
|
||||||
|
def collect_queue_depth():
|
||||||
|
gauge.set(compute_depth())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
|
metric_name = name or cast(Any, fn).__name__
|
||||||
|
self._metrics[metric_name] = Metric(
|
||||||
|
name=metric_name,
|
||||||
|
func=fn,
|
||||||
|
collect=collect,
|
||||||
|
)
|
||||||
|
return fn
|
||||||
|
|
||||||
|
if func is not None:
|
||||||
|
return decorator(func)
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def include_registry(self, registry: "MetricsRegistry") -> None:
|
||||||
|
"""Include another :class:`MetricsRegistry` into this one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registry: The registry to merge in.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If a metric name already exists in the current registry.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
main = MetricsRegistry()
|
||||||
|
sub = MetricsRegistry()
|
||||||
|
|
||||||
|
@sub.register
|
||||||
|
def sub_metric():
|
||||||
|
return Counter("sub_total", "Sub counter")
|
||||||
|
|
||||||
|
main.include_registry(sub)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
for metric_name, definition in registry._metrics.items():
|
||||||
|
if metric_name in self._metrics:
|
||||||
|
raise ValueError(
|
||||||
|
f"Metric '{metric_name}' already exists in the current registry"
|
||||||
|
)
|
||||||
|
self._metrics[metric_name] = definition
|
||||||
|
|
||||||
|
def get_all(self) -> list[Metric]:
|
||||||
|
"""Get all registered metric definitions."""
|
||||||
|
return list(self._metrics.values())
|
||||||
|
|
||||||
|
def get_providers(self) -> list[Metric]:
|
||||||
|
"""Get metric providers (called once at init)."""
|
||||||
|
return [m for m in self._metrics.values() if not m.collect]
|
||||||
|
|
||||||
|
def get_collectors(self) -> list[Metric]:
|
||||||
|
"""Get collectors (called on each scrape)."""
|
||||||
|
return [m for m in self._metrics.values() if m.collect]
|
||||||
30
src/fastapi_toolsets/pytest/__init__.py
Normal file
30
src/fastapi_toolsets/pytest/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""Pytest helpers for FastAPI testing: sessions, clients, and fixtures."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .plugin import register_fixtures
|
||||||
|
except ImportError:
|
||||||
|
from .._imports import require_extra
|
||||||
|
|
||||||
|
require_extra(package="pytest", extra="pytest")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .utils import (
|
||||||
|
cleanup_tables,
|
||||||
|
create_async_client,
|
||||||
|
create_db_session,
|
||||||
|
create_worker_database,
|
||||||
|
worker_database_url,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
from .._imports import require_extra
|
||||||
|
|
||||||
|
require_extra(package="httpx", extra="pytest")
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"cleanup_tables",
|
||||||
|
"create_async_client",
|
||||||
|
"create_db_session",
|
||||||
|
"create_worker_database",
|
||||||
|
"register_fixtures",
|
||||||
|
"worker_database_url",
|
||||||
|
]
|
||||||
@@ -1,55 +1,4 @@
|
|||||||
"""Pytest plugin for using FixtureRegistry fixtures in tests.
|
"""Pytest plugin for using FixtureRegistry fixtures in tests."""
|
||||||
|
|
||||||
This module provides utilities to automatically generate pytest fixtures
|
|
||||||
from your FixtureRegistry, with proper dependency resolution.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# conftest.py
|
|
||||||
import pytest
|
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
|
||||||
|
|
||||||
from app.fixtures import fixtures # Your FixtureRegistry
|
|
||||||
from app.models import Base
|
|
||||||
from fastapi_toolsets.pytest_plugin import register_fixtures
|
|
||||||
|
|
||||||
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/test_db"
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def engine():
|
|
||||||
engine = create_async_engine(DATABASE_URL)
|
|
||||||
yield engine
|
|
||||||
await engine.dispose()
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def db_session(engine):
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
|
||||||
|
|
||||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
|
||||||
session = session_factory()
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield session
|
|
||||||
finally:
|
|
||||||
await session.close()
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.drop_all)
|
|
||||||
|
|
||||||
# Automatically generate pytest fixtures from registry
|
|
||||||
# Creates: fixture_roles, fixture_users, fixture_posts, etc.
|
|
||||||
register_fixtures(fixtures, globals())
|
|
||||||
|
|
||||||
Usage in tests:
|
|
||||||
# test_users.py
|
|
||||||
async def test_user_count(db_session, fixture_users):
|
|
||||||
# fixture_users automatically loads fixture_roles first (if dependency)
|
|
||||||
# and returns the list of User models
|
|
||||||
assert len(fixture_users) > 0
|
|
||||||
|
|
||||||
async def test_user_role(db_session, fixture_users):
|
|
||||||
user = fixture_users[0]
|
|
||||||
assert user.role_id is not None
|
|
||||||
"""
|
|
||||||
|
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -59,7 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
from ..db import get_transaction
|
from ..db import get_transaction
|
||||||
from .fixtures import FixtureRegistry, LoadStrategy
|
from ..fixtures import FixtureRegistry, LoadStrategy
|
||||||
|
|
||||||
|
|
||||||
def register_fixtures(
|
def register_fixtures(
|
||||||
@@ -86,6 +35,7 @@ def register_fixtures(
|
|||||||
List of created fixture names
|
List of created fixture names
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
# conftest.py
|
# conftest.py
|
||||||
from app.fixtures import fixtures
|
from app.fixtures import fixtures
|
||||||
from fastapi_toolsets.pytest_plugin import register_fixtures
|
from fastapi_toolsets.pytest_plugin import register_fixtures
|
||||||
@@ -96,6 +46,7 @@ def register_fixtures(
|
|||||||
# - fixture_roles
|
# - fixture_roles
|
||||||
# - fixture_users (depends on fixture_roles if users depends on roles)
|
# - fixture_users (depends on fixture_roles if users depends on roles)
|
||||||
# - fixture_posts (depends on fixture_users if posts depends on users)
|
# - fixture_posts (depends on fixture_users if posts depends on users)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
created_fixtures: list[str] = []
|
created_fixtures: list[str] = []
|
||||||
|
|
||||||
315
src/fastapi_toolsets/pytest/utils.py
Normal file
315
src/fastapi_toolsets/pytest/utils.py
Normal file
@@ -0,0 +1,315 @@
|
|||||||
|
"""Pytest helper utilities for FastAPI testing."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from collections.abc import AsyncGenerator, Callable
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.engine import make_url
|
||||||
|
from sqlalchemy.ext.asyncio import (
|
||||||
|
AsyncSession,
|
||||||
|
async_sessionmaker,
|
||||||
|
create_async_engine,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from ..db import create_db_context
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def create_async_client(
|
||||||
|
app: Any,
|
||||||
|
base_url: str = "http://test",
|
||||||
|
dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None,
|
||||||
|
) -> AsyncGenerator[AsyncClient, None]:
|
||||||
|
"""Create an async httpx client for testing FastAPI applications.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI application instance.
|
||||||
|
base_url: Base URL for requests. Defaults to "http://test".
|
||||||
|
dependency_overrides: Optional mapping of original dependencies to
|
||||||
|
their test replacements. Applied via ``app.dependency_overrides``
|
||||||
|
before yielding and cleaned up after.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
An AsyncClient configured for the app.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.pytest import create_async_client
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client():
|
||||||
|
async with create_async_client(app) as c:
|
||||||
|
yield c
|
||||||
|
|
||||||
|
async def test_endpoint(client: AsyncClient):
|
||||||
|
response = await client.get("/health")
|
||||||
|
assert response.status_code == 200
|
||||||
|
```
|
||||||
|
|
||||||
|
Example with dependency overrides:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import create_async_client, create_db_session
|
||||||
|
from app.db import get_db
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session():
|
||||||
|
async with create_db_session(DATABASE_URL, Base, cleanup=True) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client(db_session):
|
||||||
|
async def override():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
async with create_async_client(
|
||||||
|
app, dependency_overrides={get_db: override}
|
||||||
|
) as c:
|
||||||
|
yield c
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if dependency_overrides:
|
||||||
|
app.dependency_overrides.update(dependency_overrides)
|
||||||
|
|
||||||
|
transport = ASGITransport(app=app)
|
||||||
|
try:
|
||||||
|
async with AsyncClient(transport=transport, base_url=base_url) as client:
|
||||||
|
yield client
|
||||||
|
finally:
|
||||||
|
if dependency_overrides:
|
||||||
|
for key in dependency_overrides:
|
||||||
|
app.dependency_overrides.pop(key, None)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def create_db_session(
|
||||||
|
database_url: str,
|
||||||
|
base: type[DeclarativeBase],
|
||||||
|
*,
|
||||||
|
echo: bool = False,
|
||||||
|
expire_on_commit: bool = False,
|
||||||
|
drop_tables: bool = True,
|
||||||
|
cleanup: bool = False,
|
||||||
|
) -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""Create a database session for testing.
|
||||||
|
|
||||||
|
Creates tables before yielding the session and optionally drops them after.
|
||||||
|
Each call creates a fresh engine and session for test isolation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database_url: Database connection URL (e.g., "postgresql+asyncpg://...").
|
||||||
|
base: SQLAlchemy DeclarativeBase class containing model metadata.
|
||||||
|
echo: Enable SQLAlchemy query logging. Defaults to False.
|
||||||
|
expire_on_commit: Expire objects after commit. Defaults to False.
|
||||||
|
drop_tables: Drop tables after test. Defaults to True.
|
||||||
|
cleanup: Truncate all tables after test using
|
||||||
|
:func:`cleanup_tables`. Defaults to False.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
An AsyncSession ready for database operations.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import create_db_session
|
||||||
|
from app.models import Base
|
||||||
|
|
||||||
|
DATABASE_URL = "postgresql+asyncpg://user:pass@localhost/test_db"
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session():
|
||||||
|
async with create_db_session(
|
||||||
|
DATABASE_URL, Base, cleanup=True
|
||||||
|
) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
async def test_create_user(db_session: AsyncSession):
|
||||||
|
user = User(name="test")
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
engine = create_async_engine(database_url, echo=echo)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create tables
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(base.metadata.create_all)
|
||||||
|
|
||||||
|
# Create session using existing db context utility
|
||||||
|
session_maker = async_sessionmaker(engine, expire_on_commit=expire_on_commit)
|
||||||
|
get_session = create_db_context(session_maker)
|
||||||
|
|
||||||
|
async with get_session() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
if cleanup:
|
||||||
|
await cleanup_tables(session, base)
|
||||||
|
|
||||||
|
if drop_tables:
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(base.metadata.drop_all)
|
||||||
|
finally:
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_xdist_worker(default_test_db: str) -> str:
|
||||||
|
"""Return the pytest-xdist worker name, or *default_test_db* when not running under xdist.
|
||||||
|
|
||||||
|
Reads the ``PYTEST_XDIST_WORKER`` environment variable that xdist sets
|
||||||
|
automatically in each worker process (e.g. ``"gw0"``, ``"gw1"``).
|
||||||
|
When xdist is not installed or not active, the variable is absent and
|
||||||
|
*default_test_db* is returned instead.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
default_test_db: Fallback value returned when ``PYTEST_XDIST_WORKER``
|
||||||
|
is not set.
|
||||||
|
"""
|
||||||
|
return os.environ.get("PYTEST_XDIST_WORKER", default_test_db)
|
||||||
|
|
||||||
|
|
||||||
|
def worker_database_url(database_url: str, default_test_db: str) -> str:
|
||||||
|
"""Derive a per-worker database URL for pytest-xdist parallel runs.
|
||||||
|
|
||||||
|
Appends ``_{worker_name}`` to the database name so each xdist worker
|
||||||
|
operates on its own database. When not running under xdist,
|
||||||
|
``_{default_test_db}`` is appended instead.
|
||||||
|
|
||||||
|
The worker name is read from the ``PYTEST_XDIST_WORKER`` environment
|
||||||
|
variable (set automatically by xdist in each worker process).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database_url: Original database connection URL.
|
||||||
|
default_test_db: Suffix appended to the database name when
|
||||||
|
``PYTEST_XDIST_WORKER`` is not set.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A database URL with a worker- or default-specific database name.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
# With PYTEST_XDIST_WORKER="gw0":
|
||||||
|
url = worker_database_url(
|
||||||
|
"postgresql+asyncpg://user:pass@localhost/test_db",
|
||||||
|
default_test_db="test",
|
||||||
|
)
|
||||||
|
# "postgresql+asyncpg://user:pass@localhost/test_db_gw0"
|
||||||
|
|
||||||
|
# Without PYTEST_XDIST_WORKER:
|
||||||
|
url = worker_database_url(
|
||||||
|
"postgresql+asyncpg://user:pass@localhost/test_db",
|
||||||
|
default_test_db="test",
|
||||||
|
)
|
||||||
|
# "postgresql+asyncpg://user:pass@localhost/test_db_test"
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
worker = _get_xdist_worker(default_test_db=default_test_db)
|
||||||
|
|
||||||
|
url = make_url(database_url)
|
||||||
|
url = url.set(database=f"{url.database}_{worker}")
|
||||||
|
return url.render_as_string(hide_password=False)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def create_worker_database(
|
||||||
|
database_url: str,
|
||||||
|
default_test_db: str = "test_db",
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""Create and drop a per-worker database for pytest-xdist isolation.
|
||||||
|
|
||||||
|
Intended for use as a **session-scoped** fixture. Connects to the server
|
||||||
|
using the original *database_url* (with ``AUTOCOMMIT`` isolation for DDL),
|
||||||
|
creates a dedicated database for the worker, and yields the worker-specific
|
||||||
|
URL. On cleanup the worker database is dropped.
|
||||||
|
|
||||||
|
When running under xdist the database name is suffixed with the worker
|
||||||
|
name (e.g. ``_gw0``). Otherwise it is suffixed with *default_test_db*.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database_url: Original database connection URL.
|
||||||
|
default_test_db: Suffix appended to the database name when
|
||||||
|
``PYTEST_XDIST_WORKER`` is not set. Defaults to ``"test_db"``.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The worker-specific database URL.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import (
|
||||||
|
create_worker_database, create_db_session,
|
||||||
|
)
|
||||||
|
|
||||||
|
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost/test_db"
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
async def worker_db_url():
|
||||||
|
async with create_worker_database(DATABASE_URL) as url:
|
||||||
|
yield url
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session(worker_db_url):
|
||||||
|
async with create_db_session(
|
||||||
|
worker_db_url, Base, cleanup=True
|
||||||
|
) as session:
|
||||||
|
yield session
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
worker_url = worker_database_url(
|
||||||
|
database_url=database_url, default_test_db=default_test_db
|
||||||
|
)
|
||||||
|
worker_db_name = make_url(worker_url).database
|
||||||
|
|
||||||
|
engine = create_async_engine(
|
||||||
|
database_url,
|
||||||
|
isolation_level="AUTOCOMMIT",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
await conn.execute(text(f"DROP DATABASE IF EXISTS {worker_db_name}"))
|
||||||
|
await conn.execute(text(f"CREATE DATABASE {worker_db_name}"))
|
||||||
|
|
||||||
|
yield worker_url
|
||||||
|
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
await conn.execute(text(f"DROP DATABASE IF EXISTS {worker_db_name}"))
|
||||||
|
finally:
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
async def cleanup_tables(
|
||||||
|
session: AsyncSession,
|
||||||
|
base: type[DeclarativeBase],
|
||||||
|
) -> None:
|
||||||
|
"""Truncate all tables for fast between-test cleanup.
|
||||||
|
|
||||||
|
Executes a single ``TRUNCATE … RESTART IDENTITY CASCADE`` statement
|
||||||
|
across every table in *base*'s metadata, which is significantly faster
|
||||||
|
than dropping and re-creating tables between tests.
|
||||||
|
|
||||||
|
This is a no-op when the metadata contains no tables.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: An active async database session.
|
||||||
|
base: SQLAlchemy DeclarativeBase class containing model metadata.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session(worker_db_url):
|
||||||
|
async with create_db_session(worker_db_url, Base) as session:
|
||||||
|
yield session
|
||||||
|
await cleanup_tables(session, Base)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
tables = base.metadata.sorted_tables
|
||||||
|
if not tables:
|
||||||
|
return
|
||||||
|
|
||||||
|
table_names = ", ".join(f'"{t.name}"' for t in tables)
|
||||||
|
await session.execute(text(f"TRUNCATE {table_names} RESTART IDENTITY CASCADE"))
|
||||||
|
await session.commit()
|
||||||
@@ -1,15 +1,18 @@
|
|||||||
"""Base Pydantic schemas for API responses."""
|
"""Base Pydantic schemas for API responses."""
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import ClassVar, Generic, TypeVar
|
from typing import Any, ClassVar, Generic, TypeVar
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ApiError",
|
"ApiError",
|
||||||
|
"CursorPagination",
|
||||||
"ErrorResponse",
|
"ErrorResponse",
|
||||||
|
"OffsetPagination",
|
||||||
"Pagination",
|
"Pagination",
|
||||||
"PaginatedResponse",
|
"PaginatedResponse",
|
||||||
|
"PydanticBase",
|
||||||
"Response",
|
"Response",
|
||||||
"ResponseStatus",
|
"ResponseStatus",
|
||||||
]
|
]
|
||||||
@@ -49,6 +52,7 @@ class ApiError(PydanticBase):
|
|||||||
msg: str
|
msg: str
|
||||||
desc: str
|
desc: str
|
||||||
err_code: str
|
err_code: str
|
||||||
|
data: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
class BaseResponse(PydanticBase):
|
class BaseResponse(PydanticBase):
|
||||||
@@ -69,7 +73,9 @@ class Response(BaseResponse, Generic[DataT]):
|
|||||||
"""Generic API response with data payload.
|
"""Generic API response with data payload.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
Response[UserRead](data=user, message="User retrieved")
|
Response[UserRead](data=user, message="User retrieved")
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data: DataT | None = None
|
data: DataT | None = None
|
||||||
@@ -83,11 +89,11 @@ class ErrorResponse(BaseResponse):
|
|||||||
|
|
||||||
status: ResponseStatus = ResponseStatus.FAIL
|
status: ResponseStatus = ResponseStatus.FAIL
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
data: None = None
|
data: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
class Pagination(PydanticBase):
|
class OffsetPagination(PydanticBase):
|
||||||
"""Pagination metadata for list responses.
|
"""Pagination metadata for offset-based list responses.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
total_count: Total number of items across all pages
|
total_count: Total number of items across all pages
|
||||||
@@ -102,15 +108,28 @@ class Pagination(PydanticBase):
|
|||||||
has_more: bool
|
has_more: bool
|
||||||
|
|
||||||
|
|
||||||
class PaginatedResponse(BaseResponse, Generic[DataT]):
|
# Backward-compatible - will be removed in v2.0
|
||||||
"""Paginated API response for list endpoints.
|
Pagination = OffsetPagination
|
||||||
|
|
||||||
Example:
|
|
||||||
PaginatedResponse[UserRead](
|
class CursorPagination(PydanticBase):
|
||||||
data=users,
|
"""Pagination metadata for cursor-based list responses.
|
||||||
pagination=Pagination(total_count=100, items_per_page=10, page=1, has_more=True)
|
|
||||||
)
|
Attributes:
|
||||||
|
next_cursor: Encoded cursor for the next page, or None on the last page.
|
||||||
|
prev_cursor: Encoded cursor for the previous page, or None on the first page.
|
||||||
|
items_per_page: Number of items requested per page.
|
||||||
|
has_more: Whether there is at least one more page after this one.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
next_cursor: str | None
|
||||||
|
prev_cursor: str | None = None
|
||||||
|
items_per_page: int
|
||||||
|
has_more: bool
|
||||||
|
|
||||||
|
|
||||||
|
class PaginatedResponse(BaseResponse, Generic[DataT]):
|
||||||
|
"""Paginated API response for list endpoints."""
|
||||||
|
|
||||||
data: list[DataT]
|
data: list[DataT]
|
||||||
pagination: Pagination
|
pagination: OffsetPagination | CursorPagination
|
||||||
|
|||||||
@@ -1,27 +1,23 @@
|
|||||||
"""Shared pytest fixtures for fastapi-utils tests."""
|
"""Shared pytest fixtures for fastapi-utils tests."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import ForeignKey, String
|
from sqlalchemy import Column, ForeignKey, Integer, String, Table, Uuid
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from fastapi_toolsets.crud import CrudFactory
|
from fastapi_toolsets.crud import CrudFactory
|
||||||
|
from fastapi_toolsets.schemas import PydanticBase
|
||||||
|
|
||||||
# PostgreSQL connection URL from environment or default for local development
|
DATABASE_URL = os.getenv(
|
||||||
DATABASE_URL = os.getenv("DATABASE_URL") or os.getenv(
|
key="DATABASE_URL",
|
||||||
"TEST_DATABASE_URL",
|
default="postgresql+asyncpg://postgres:postgres@localhost:5432/postgres",
|
||||||
"postgresql+asyncpg://postgres:postgres@localhost:5432/fastapi_toolsets_test",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Test Models
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
class Base(DeclarativeBase):
|
||||||
"""Base class for test models."""
|
"""Base class for test models."""
|
||||||
|
|
||||||
@@ -33,7 +29,7 @@ class Role(Base):
|
|||||||
|
|
||||||
__tablename__ = "roles"
|
__tablename__ = "roles"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
name: Mapped[str] = mapped_column(String(50), unique=True)
|
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
|
|
||||||
users: Mapped[list["User"]] = relationship(back_populates="role")
|
users: Mapped[list["User"]] = relationship(back_populates="role")
|
||||||
@@ -44,36 +40,70 @@ class User(Base):
|
|||||||
|
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
username: Mapped[str] = mapped_column(String(50), unique=True)
|
username: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
email: Mapped[str] = mapped_column(String(100), unique=True)
|
email: Mapped[str] = mapped_column(String(100), unique=True)
|
||||||
is_active: Mapped[bool] = mapped_column(default=True)
|
is_active: Mapped[bool] = mapped_column(default=True)
|
||||||
role_id: Mapped[int | None] = mapped_column(ForeignKey("roles.id"), nullable=True)
|
role_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
|
ForeignKey("roles.id"), nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
role: Mapped[Role | None] = relationship(back_populates="users")
|
role: Mapped[Role | None] = relationship(back_populates="users")
|
||||||
|
|
||||||
|
|
||||||
|
class Tag(Base):
|
||||||
|
"""Test tag model."""
|
||||||
|
|
||||||
|
__tablename__ = "tags"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
|
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
|
|
||||||
|
|
||||||
|
post_tags = Table(
|
||||||
|
"post_tags",
|
||||||
|
Base.metadata,
|
||||||
|
Column(
|
||||||
|
"post_id", Uuid, ForeignKey("posts.id", ondelete="CASCADE"), primary_key=True
|
||||||
|
),
|
||||||
|
Column("tag_id", Uuid, ForeignKey("tags.id", ondelete="CASCADE"), primary_key=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class IntRole(Base):
|
||||||
|
"""Test role model with auto-increment integer PK."""
|
||||||
|
|
||||||
|
__tablename__ = "int_roles"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
|
|
||||||
|
|
||||||
class Post(Base):
|
class Post(Base):
|
||||||
"""Test post model."""
|
"""Test post model."""
|
||||||
|
|
||||||
__tablename__ = "posts"
|
__tablename__ = "posts"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
title: Mapped[str] = mapped_column(String(200))
|
title: Mapped[str] = mapped_column(String(200))
|
||||||
content: Mapped[str] = mapped_column(String(1000), default="")
|
content: Mapped[str] = mapped_column(String(1000), default="")
|
||||||
is_published: Mapped[bool] = mapped_column(default=False)
|
is_published: Mapped[bool] = mapped_column(default=False)
|
||||||
author_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
|
author_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
|
||||||
|
|
||||||
|
tags: Mapped[list[Tag]] = relationship(secondary=post_tags)
|
||||||
# =============================================================================
|
|
||||||
# Test Schemas
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class RoleCreate(BaseModel):
|
class RoleCreate(BaseModel):
|
||||||
"""Schema for creating a role."""
|
"""Schema for creating a role."""
|
||||||
|
|
||||||
id: int | None = None
|
id: uuid.UUID | None = None
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class RoleRead(PydanticBase):
|
||||||
|
"""Schema for reading a role."""
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
@@ -86,11 +116,18 @@ class RoleUpdate(BaseModel):
|
|||||||
class UserCreate(BaseModel):
|
class UserCreate(BaseModel):
|
||||||
"""Schema for creating a user."""
|
"""Schema for creating a user."""
|
||||||
|
|
||||||
id: int | None = None
|
id: uuid.UUID | None = None
|
||||||
username: str
|
username: str
|
||||||
email: str
|
email: str
|
||||||
is_active: bool = True
|
is_active: bool = True
|
||||||
role_id: int | None = None
|
role_id: uuid.UUID | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class UserRead(PydanticBase):
|
||||||
|
"""Schema for reading a user (subset of fields — no email)."""
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
|
username: str
|
||||||
|
|
||||||
|
|
||||||
class UserUpdate(BaseModel):
|
class UserUpdate(BaseModel):
|
||||||
@@ -99,17 +136,24 @@ class UserUpdate(BaseModel):
|
|||||||
username: str | None = None
|
username: str | None = None
|
||||||
email: str | None = None
|
email: str | None = None
|
||||||
is_active: bool | None = None
|
is_active: bool | None = None
|
||||||
role_id: int | None = None
|
role_id: uuid.UUID | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TagCreate(BaseModel):
|
||||||
|
"""Schema for creating a tag."""
|
||||||
|
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
class PostCreate(BaseModel):
|
class PostCreate(BaseModel):
|
||||||
"""Schema for creating a post."""
|
"""Schema for creating a post."""
|
||||||
|
|
||||||
id: int | None = None
|
id: uuid.UUID | None = None
|
||||||
title: str
|
title: str
|
||||||
content: str = ""
|
content: str = ""
|
||||||
is_published: bool = False
|
is_published: bool = False
|
||||||
author_id: int
|
author_id: uuid.UUID
|
||||||
|
|
||||||
|
|
||||||
class PostUpdate(BaseModel):
|
class PostUpdate(BaseModel):
|
||||||
@@ -120,18 +164,40 @@ class PostUpdate(BaseModel):
|
|||||||
is_published: bool | None = None
|
is_published: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
class PostM2MCreate(BaseModel):
|
||||||
# CRUD Classes
|
"""Schema for creating a post with M2M tag IDs."""
|
||||||
# =============================================================================
|
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
title: str
|
||||||
|
content: str = ""
|
||||||
|
is_published: bool = False
|
||||||
|
author_id: uuid.UUID
|
||||||
|
tag_ids: list[uuid.UUID] = []
|
||||||
|
|
||||||
|
|
||||||
|
class PostM2MUpdate(BaseModel):
|
||||||
|
"""Schema for updating a post with M2M tag IDs."""
|
||||||
|
|
||||||
|
title: str | None = None
|
||||||
|
content: str | None = None
|
||||||
|
is_published: bool | None = None
|
||||||
|
tag_ids: list[uuid.UUID] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class IntRoleCreate(BaseModel):
|
||||||
|
"""Schema for creating an IntRole."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
RoleCrud = CrudFactory(Role)
|
RoleCrud = CrudFactory(Role)
|
||||||
|
RoleCursorCrud = CrudFactory(Role, cursor_column=Role.id)
|
||||||
|
IntRoleCursorCrud = CrudFactory(IntRole, cursor_column=IntRole.id)
|
||||||
UserCrud = CrudFactory(User)
|
UserCrud = CrudFactory(User)
|
||||||
|
UserCursorCrud = CrudFactory(User, cursor_column=User.id)
|
||||||
PostCrud = CrudFactory(Post)
|
PostCrud = CrudFactory(Post)
|
||||||
|
TagCrud = CrudFactory(Tag)
|
||||||
|
PostM2MCrud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
|
||||||
# =============================================================================
|
|
||||||
# Fixtures
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -195,5 +261,5 @@ def sample_post_data() -> PostCreate:
|
|||||||
title="Test Post",
|
title="Test Post",
|
||||||
content="Test content",
|
content="Test content",
|
||||||
is_published=True,
|
is_published=True,
|
||||||
author_id=1,
|
author_id=uuid.uuid4(),
|
||||||
)
|
)
|
||||||
|
|||||||
543
tests/test_cli.py
Normal file
543
tests/test_cli.py
Normal file
@@ -0,0 +1,543 @@
|
|||||||
|
"""Tests for fastapi_toolsets.cli module."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli.config import (
|
||||||
|
get_config_value,
|
||||||
|
get_custom_cli,
|
||||||
|
get_db_context,
|
||||||
|
get_fixtures_registry,
|
||||||
|
import_from_string,
|
||||||
|
)
|
||||||
|
from fastapi_toolsets.cli.pyproject import find_pyproject, load_pyproject
|
||||||
|
from fastapi_toolsets.cli.utils import async_command
|
||||||
|
from fastapi_toolsets.fixtures import FixtureRegistry
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
|
||||||
|
|
||||||
|
class TestPyproject:
|
||||||
|
"""Tests for pyproject.toml discovery and loading."""
|
||||||
|
|
||||||
|
def test_find_pyproject_in_current_dir(self, tmp_path, monkeypatch):
|
||||||
|
"""Finds pyproject.toml in current directory."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[project]\nname = 'test'\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = find_pyproject()
|
||||||
|
assert result == pyproject
|
||||||
|
|
||||||
|
def test_find_pyproject_in_parent_dir(self, tmp_path, monkeypatch):
|
||||||
|
"""Finds pyproject.toml in parent directory."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[project]\nname = 'test'\n")
|
||||||
|
subdir = tmp_path / "src" / "app"
|
||||||
|
subdir.mkdir(parents=True)
|
||||||
|
monkeypatch.chdir(subdir)
|
||||||
|
|
||||||
|
result = find_pyproject()
|
||||||
|
assert result == pyproject
|
||||||
|
|
||||||
|
def test_find_pyproject_not_found(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns None when no pyproject.toml exists."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
result = find_pyproject()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_load_pyproject_returns_tool_config(self, tmp_path, monkeypatch):
|
||||||
|
"""load_pyproject returns the [tool.fastapi-toolsets] section."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
'[tool.fastapi-toolsets]\nfixtures = "app.fixtures:registry"\n'
|
||||||
|
)
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = load_pyproject()
|
||||||
|
assert result == {"fixtures": "app.fixtures:registry"}
|
||||||
|
|
||||||
|
def test_load_pyproject_empty_when_no_file(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns empty dict when no pyproject.toml exists."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
result = load_pyproject()
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_load_pyproject_empty_when_no_tool_section(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns empty dict when no [tool.fastapi-toolsets] section."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[project]\nname = 'test'\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = load_pyproject()
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_load_pyproject_invalid_toml(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns empty dict when pyproject.toml is invalid."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("invalid toml {{{")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = load_pyproject()
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestImportFromString:
|
||||||
|
"""Tests for import_from_string function."""
|
||||||
|
|
||||||
|
def test_import_valid_path(self):
|
||||||
|
"""Import valid module:attribute path."""
|
||||||
|
result = import_from_string("fastapi_toolsets.fixtures:FixtureRegistry")
|
||||||
|
assert result is FixtureRegistry
|
||||||
|
|
||||||
|
def test_import_without_colon_raises_error(self):
|
||||||
|
"""Import path without colon raises error."""
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
import_from_string("fastapi_toolsets.fixtures.FixtureRegistry")
|
||||||
|
assert "Expected format: 'module:attribute'" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_import_nonexistent_module_raises_error(self):
|
||||||
|
"""Import nonexistent module raises error."""
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
import_from_string("nonexistent.module:something")
|
||||||
|
assert "Cannot import module" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_import_nonexistent_attribute_raises_error(self):
|
||||||
|
"""Import nonexistent attribute raises error."""
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
import_from_string("fastapi_toolsets.fixtures:NonexistentClass")
|
||||||
|
assert "has no attribute" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetConfigValue:
|
||||||
|
"""Tests for get_config_value function."""
|
||||||
|
|
||||||
|
def test_get_existing_value(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns value when key exists."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text('[tool.fastapi-toolsets]\nfixtures = "app:registry"\n')
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = get_config_value("fixtures")
|
||||||
|
assert result == "app:registry"
|
||||||
|
|
||||||
|
def test_get_missing_value_returns_none(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns None when key is missing and not required."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = get_config_value("fixtures")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_get_missing_value_required_raises_error(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when key is missing and required."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_config_value("fixtures", required=True)
|
||||||
|
assert "No 'fixtures' configured" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetFixturesRegistry:
|
||||||
|
"""Tests for get_fixtures_registry function."""
|
||||||
|
|
||||||
|
def test_raises_when_not_configured(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when fixtures not configured."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_fixtures_registry()
|
||||||
|
assert "No 'fixtures' configured" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_raises_when_not_registry_instance(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when imported object is not a FixtureRegistry."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
'[tool.fastapi-toolsets]\nfixtures = "my_fixtures:registry"\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
fixtures_file = tmp_path / "my_fixtures.py"
|
||||||
|
fixtures_file.write_text("registry = 'not a registry'\n")
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
try:
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_fixtures_registry()
|
||||||
|
assert "must be a FixtureRegistry instance" in str(exc_info.value)
|
||||||
|
finally:
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
if "my_fixtures" in sys.modules:
|
||||||
|
del sys.modules["my_fixtures"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetDbContext:
|
||||||
|
"""Tests for get_db_context function."""
|
||||||
|
|
||||||
|
def test_raises_when_not_configured(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when db_context not configured."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_db_context()
|
||||||
|
assert "No 'db_context' configured" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCustomCli:
|
||||||
|
"""Tests for get_custom_cli function."""
|
||||||
|
|
||||||
|
def test_returns_none_when_not_configured(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns None when custom_cli not configured."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = get_custom_cli()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_raises_when_not_typer_instance(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when imported object is not a Typer instance."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text('[tool.fastapi-toolsets]\ncustom_cli = "my_cli:cli"\n')
|
||||||
|
|
||||||
|
cli_file = tmp_path / "my_cli.py"
|
||||||
|
cli_file.write_text("cli = 'not a typer'\n")
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
try:
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_custom_cli()
|
||||||
|
assert "must be a Typer instance" in str(exc_info.value)
|
||||||
|
finally:
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
if "my_cli" in sys.modules:
|
||||||
|
del sys.modules["my_cli"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestCliApp:
|
||||||
|
"""Tests for CLI application."""
|
||||||
|
|
||||||
|
def test_cli_help(self, tmp_path, monkeypatch):
|
||||||
|
"""CLI shows help without fixtures."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Need to reload the module to pick up new cwd
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
result = runner.invoke(app.cli, ["--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "CLI utilities for FastAPI projects" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
class TestFixturesCli:
|
||||||
|
"""Tests for fixtures CLI commands."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def cli_env(self, tmp_path, monkeypatch):
|
||||||
|
"""Set up CLI environment with fixtures config."""
|
||||||
|
# Create pyproject.toml
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
"[tool.fastapi-toolsets]\n"
|
||||||
|
'fixtures = "fixtures:registry"\n'
|
||||||
|
'db_context = "db:get_session"\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create fixtures module
|
||||||
|
fixtures_file = tmp_path / "fixtures.py"
|
||||||
|
fixtures_file.write_text(
|
||||||
|
"from fastapi_toolsets.fixtures import FixtureRegistry, Context\n"
|
||||||
|
"\n"
|
||||||
|
"registry = FixtureRegistry()\n"
|
||||||
|
"\n"
|
||||||
|
"@registry.register(contexts=[Context.BASE])\n"
|
||||||
|
"def roles():\n"
|
||||||
|
' return [{"id": 1, "name": "admin"}, {"id": 2, "name": "user"}]\n'
|
||||||
|
"\n"
|
||||||
|
'@registry.register(depends_on=["roles"], contexts=[Context.TESTING])\n'
|
||||||
|
"def users():\n"
|
||||||
|
' return [{"id": 1, "name": "alice", "role_id": 1}]\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create db module
|
||||||
|
db_file = tmp_path / "db.py"
|
||||||
|
db_file.write_text(
|
||||||
|
"from contextlib import asynccontextmanager\n"
|
||||||
|
"\n"
|
||||||
|
"@asynccontextmanager\n"
|
||||||
|
"async def get_session():\n"
|
||||||
|
" yield None\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Add tmp_path to sys.path for imports
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
# Reload the CLI module to pick up new config
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
yield tmp_path, app.cli
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
|
||||||
|
def test_fixtures_list(self, cli_env):
|
||||||
|
"""fixtures list shows registered fixtures."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(cli, ["fixtures", "list"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "roles" in result.output
|
||||||
|
assert "users" in result.output
|
||||||
|
assert "Total: 2 fixture(s)" in result.output
|
||||||
|
|
||||||
|
def test_fixtures_list_with_context(self, cli_env):
|
||||||
|
"""fixtures list --context filters by context."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(cli, ["fixtures", "list", "--context", "base"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "roles" in result.output
|
||||||
|
assert "users" not in result.output
|
||||||
|
assert "Total: 1 fixture(s)" in result.output
|
||||||
|
|
||||||
|
def test_fixtures_load_dry_run(self, cli_env):
|
||||||
|
"""fixtures load --dry-run shows what would be loaded."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(cli, ["fixtures", "load", "base", "--dry-run"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Fixtures to load" in result.output
|
||||||
|
assert "roles" in result.output
|
||||||
|
assert "[Dry run - no changes made]" in result.output
|
||||||
|
|
||||||
|
def test_fixtures_load_invalid_strategy(self, cli_env):
|
||||||
|
"""fixtures load with invalid strategy shows error."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(
|
||||||
|
cli, ["fixtures", "load", "base", "--strategy", "invalid"]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code != 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestCliWithoutFixturesConfig:
|
||||||
|
"""Tests for CLI when fixtures is not configured."""
|
||||||
|
|
||||||
|
def test_no_fixtures_command(self, tmp_path, monkeypatch):
|
||||||
|
"""fixtures command is not available when not configured."""
|
||||||
|
# Create pyproject.toml without fixtures
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text('[project]\nname = "test"\n')
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Reload the CLI module
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
result = runner.invoke(app.cli, ["--help"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "fixtures" not in result.output
|
||||||
|
|
||||||
|
|
||||||
|
class TestCustomCliConfig:
|
||||||
|
"""Tests for custom CLI configuration."""
|
||||||
|
|
||||||
|
def test_cli_with_custom_cli(self, tmp_path, monkeypatch):
|
||||||
|
"""CLI uses custom Typer instance when configured."""
|
||||||
|
import typer
|
||||||
|
|
||||||
|
# Create pyproject.toml with custom_cli config
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text('[tool.fastapi-toolsets]\ncustom_cli = "my_cli:cli"\n')
|
||||||
|
|
||||||
|
# Create custom CLI module with its own Typer and commands
|
||||||
|
cli_file = tmp_path / "my_cli.py"
|
||||||
|
cli_file.write_text(
|
||||||
|
"import typer\n"
|
||||||
|
"\n"
|
||||||
|
"cli = typer.Typer(name='my-app', help='My custom CLI')\n"
|
||||||
|
"\n"
|
||||||
|
"@cli.command()\n"
|
||||||
|
"def hello():\n"
|
||||||
|
' print("Hello from custom CLI!")\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Add tmp_path to sys.path for imports
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
# Remove my_cli from sys.modules if it was previously loaded
|
||||||
|
if "my_cli" in sys.modules:
|
||||||
|
del sys.modules["my_cli"]
|
||||||
|
|
||||||
|
# Reload the CLI module to pick up new config
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Verify custom CLI is used
|
||||||
|
assert isinstance(app.cli, typer.Typer)
|
||||||
|
|
||||||
|
result = runner.invoke(app.cli, ["--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "My custom CLI" in result.output
|
||||||
|
assert "hello" in result.output
|
||||||
|
|
||||||
|
result = runner.invoke(app.cli, ["hello"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Hello from custom CLI!" in result.output
|
||||||
|
finally:
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
if "my_cli" in sys.modules:
|
||||||
|
del sys.modules["my_cli"]
|
||||||
|
|
||||||
|
def test_custom_cli_with_fixtures(self, tmp_path, monkeypatch):
|
||||||
|
"""Custom CLI gets fixtures command added when configured."""
|
||||||
|
# Create pyproject.toml with both custom_cli and fixtures
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
"[tool.fastapi-toolsets]\n"
|
||||||
|
'custom_cli = "my_cli:cli"\n'
|
||||||
|
'fixtures = "fixtures:registry"\n'
|
||||||
|
'db_context = "db:get_session"\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create custom CLI module
|
||||||
|
cli_file = tmp_path / "my_cli.py"
|
||||||
|
cli_file.write_text(
|
||||||
|
"import typer\n"
|
||||||
|
"\n"
|
||||||
|
"cli = typer.Typer(name='my-app', help='My custom CLI')\n"
|
||||||
|
"\n"
|
||||||
|
"@cli.command()\n"
|
||||||
|
"def hello():\n"
|
||||||
|
' print("Hello!")\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create fixtures module
|
||||||
|
fixtures_file = tmp_path / "fixtures.py"
|
||||||
|
fixtures_file.write_text(
|
||||||
|
"from fastapi_toolsets.fixtures import FixtureRegistry\n"
|
||||||
|
"\n"
|
||||||
|
"registry = FixtureRegistry()\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create db module
|
||||||
|
db_file = tmp_path / "db.py"
|
||||||
|
db_file.write_text(
|
||||||
|
"from contextlib import asynccontextmanager\n"
|
||||||
|
"\n"
|
||||||
|
"@asynccontextmanager\n"
|
||||||
|
"async def get_session():\n"
|
||||||
|
" yield None\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
for mod in ["my_cli", "fixtures", "db"]:
|
||||||
|
if mod in sys.modules:
|
||||||
|
del sys.modules[mod]
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = runner.invoke(app.cli, ["--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
# Should have both custom command and fixtures
|
||||||
|
assert "hello" in result.output
|
||||||
|
assert "fixtures" in result.output
|
||||||
|
finally:
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
for mod in ["my_cli", "fixtures", "db"]:
|
||||||
|
if mod in sys.modules:
|
||||||
|
del sys.modules[mod]
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncCommand:
|
||||||
|
"""Tests for async_command decorator."""
|
||||||
|
|
||||||
|
def test_async_command_runs_coroutine(self):
|
||||||
|
"""async_command runs async function synchronously."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def async_func(value: int) -> int:
|
||||||
|
return value * 2
|
||||||
|
|
||||||
|
result = async_func(21)
|
||||||
|
assert result == 42
|
||||||
|
|
||||||
|
def test_async_command_preserves_signature(self):
|
||||||
|
"""async_command preserves function signature."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def async_func(name: str, count: int = 1) -> str:
|
||||||
|
return f"{name} x {count}"
|
||||||
|
|
||||||
|
result = async_func("test", count=3)
|
||||||
|
assert result == "test x 3"
|
||||||
|
|
||||||
|
def test_async_command_preserves_docstring(self):
|
||||||
|
"""async_command preserves function docstring."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def async_func() -> None:
|
||||||
|
"""This is a docstring."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert async_func.__doc__ == """This is a docstring."""
|
||||||
|
|
||||||
|
def test_async_command_preserves_name(self):
|
||||||
|
"""async_command preserves function name."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def my_async_function() -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert my_async_function.__name__ == "my_async_function"
|
||||||
1584
tests/test_crud.py
1584
tests/test_crud.py
File diff suppressed because it is too large
Load Diff
434
tests/test_crud_search.py
Normal file
434
tests/test_crud_search.py
Normal file
@@ -0,0 +1,434 @@
|
|||||||
|
"""Tests for CRUD search functionality."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from fastapi_toolsets.crud import SearchConfig, get_searchable_fields
|
||||||
|
from fastapi_toolsets.schemas import OffsetPagination
|
||||||
|
|
||||||
|
from .conftest import (
|
||||||
|
Role,
|
||||||
|
RoleCreate,
|
||||||
|
RoleCrud,
|
||||||
|
User,
|
||||||
|
UserCreate,
|
||||||
|
UserCrud,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPaginateSearch:
|
||||||
|
"""Tests for paginate() with search."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_single_column(self, db_session: AsyncSession):
|
||||||
|
"""Search on a single direct column."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="john_doe", email="john@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="jane_doe", email="jane@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob_smith", email="bob@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search="doe",
|
||||||
|
search_fields=[User.username],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_multiple_columns(self, db_session: AsyncSession):
|
||||||
|
"""Search across multiple columns (OR logic)."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="alice@company.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="company_bob", email="bob@other.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search="company",
|
||||||
|
search_fields=[User.username, User.email],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_relationship_depth1(self, db_session: AsyncSession):
|
||||||
|
"""Search through a relationship (depth 1)."""
|
||||||
|
admin_role = await RoleCrud.create(db_session, RoleCreate(name="administrator"))
|
||||||
|
user_role = await RoleCrud.create(db_session, RoleCreate(name="basic_user"))
|
||||||
|
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="admin1", email="a1@test.com", role_id=admin_role.id),
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="admin2", email="a2@test.com", role_id=admin_role.id),
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="user1", email="u1@test.com", role_id=user_role.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search="admin",
|
||||||
|
search_fields=[(User.role, Role.name)],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_mixed_direct_and_relation(self, db_session: AsyncSession):
|
||||||
|
"""Search combining direct columns and relationships."""
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="john", email="john@test.com", role_id=role.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search "admin" in username OR role.name
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search="admin",
|
||||||
|
search_fields=[User.username, (User.role, Role.name)],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_case_insensitive(self, db_session: AsyncSession):
|
||||||
|
"""Search is case-insensitive by default."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="JohnDoe", email="j@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search="johndoe",
|
||||||
|
search_fields=[User.username],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_case_sensitive(self, db_session: AsyncSession):
|
||||||
|
"""Case-sensitive search with SearchConfig."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="JohnDoe", email="j@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not find (case mismatch)
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search=SearchConfig(query="johndoe", case_sensitive=True),
|
||||||
|
search_fields=[User.username],
|
||||||
|
)
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 0
|
||||||
|
|
||||||
|
# Should find (case match)
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search=SearchConfig(query="JohnDoe", case_sensitive=True),
|
||||||
|
search_fields=[User.username],
|
||||||
|
)
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_empty_query(self, db_session: AsyncSession):
|
||||||
|
"""Empty search returns all results."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="user1", email="u1@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="user2", email="u2@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(db_session, search="")
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(db_session, search=None)
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_with_existing_filters(self, db_session: AsyncSession):
|
||||||
|
"""Search combines with existing filters (AND)."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="active_john", email="aj@test.com", is_active=True),
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="inactive_john", email="ij@test.com", is_active=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
filters=[User.is_active == True], # noqa: E712
|
||||||
|
search="john",
|
||||||
|
search_fields=[User.username],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
assert result.data[0].username == "active_john"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_auto_detect_fields(self, db_session: AsyncSession):
|
||||||
|
"""Auto-detect searchable fields when not specified."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="findme", email="other@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(db_session, search="findme")
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_no_results(self, db_session: AsyncSession):
|
||||||
|
"""Search with no matching results."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="john", email="j@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search="nonexistent",
|
||||||
|
search_fields=[User.username],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 0
|
||||||
|
assert result.data == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_with_pagination(self, db_session: AsyncSession):
|
||||||
|
"""Search respects pagination parameters."""
|
||||||
|
for i in range(15):
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username=f"user_{i}", email=f"user{i}@test.com"),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search="user_",
|
||||||
|
search_fields=[User.username],
|
||||||
|
page=1,
|
||||||
|
items_per_page=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 15
|
||||||
|
assert len(result.data) == 5
|
||||||
|
assert result.pagination.has_more is True
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_null_relationship(self, db_session: AsyncSession):
|
||||||
|
"""Users without relationship are included (outerjoin)."""
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="with_role", email="wr@test.com", role_id=role.id),
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="no_role", email="nr@test.com", role_id=None),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search in username, not in role
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search="role",
|
||||||
|
search_fields=[User.username],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_with_order_by(self, db_session: AsyncSession):
|
||||||
|
"""Search works with order_by parameter."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="charlie", email="c@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="b@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search="@test.com",
|
||||||
|
search_fields=[User.email],
|
||||||
|
order_by=User.username,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 3
|
||||||
|
usernames = [u.username for u in result.data]
|
||||||
|
assert usernames == ["alice", "bob", "charlie"]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_non_string_column(self, db_session: AsyncSession):
|
||||||
|
"""Search on non-string columns (e.g., UUID) works via cast."""
|
||||||
|
user_id = uuid.UUID("12345678-1234-5678-1234-567812345678")
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(id=user_id, username="john", email="john@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="jane", email="jane@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search by UUID (partial match)
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search="12345678",
|
||||||
|
search_fields=[User.id, User.username],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
assert result.data[0].id == user_id
|
||||||
|
|
||||||
|
|
||||||
|
class TestSearchConfig:
|
||||||
|
"""Tests for SearchConfig options."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_match_mode_all(self, db_session: AsyncSession):
|
||||||
|
"""match_mode='all' requires all fields to match (AND)."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="john_test", email="john_test@company.com"),
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="john_other", email="other@example.com"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 'john' must be in username AND email
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search=SearchConfig(query="john", match_mode="all"),
|
||||||
|
search_fields=[User.username, User.email],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
assert result.data[0].username == "john_test"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_config_with_fields(self, db_session: AsyncSession):
|
||||||
|
"""SearchConfig can specify fields directly."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="test", email="findme@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search=SearchConfig(query="findme", fields=[User.email]),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestNoSearchableFieldsError:
|
||||||
|
"""Tests for NoSearchableFieldsError exception."""
|
||||||
|
|
||||||
|
def test_error_is_api_exception(self):
|
||||||
|
"""NoSearchableFieldsError inherits from ApiException."""
|
||||||
|
from fastapi_toolsets.exceptions import ApiException, NoSearchableFieldsError
|
||||||
|
|
||||||
|
assert issubclass(NoSearchableFieldsError, ApiException)
|
||||||
|
|
||||||
|
def test_error_has_api_error_fields(self):
|
||||||
|
"""NoSearchableFieldsError has proper api_error configuration."""
|
||||||
|
from fastapi_toolsets.exceptions import NoSearchableFieldsError
|
||||||
|
|
||||||
|
assert NoSearchableFieldsError.api_error.code == 400
|
||||||
|
assert NoSearchableFieldsError.api_error.err_code == "SEARCH-400"
|
||||||
|
|
||||||
|
def test_error_message_contains_model_name(self):
|
||||||
|
"""Error message includes the model name."""
|
||||||
|
from fastapi_toolsets.exceptions import NoSearchableFieldsError
|
||||||
|
|
||||||
|
error = NoSearchableFieldsError(User)
|
||||||
|
assert "User" in str(error)
|
||||||
|
assert error.model is User
|
||||||
|
|
||||||
|
def test_error_raised_when_no_fields(self):
|
||||||
|
"""Error is raised when search has no searchable fields."""
|
||||||
|
from sqlalchemy import Integer
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||||
|
|
||||||
|
from fastapi_toolsets.crud.search import build_search_filters
|
||||||
|
from fastapi_toolsets.exceptions import NoSearchableFieldsError
|
||||||
|
|
||||||
|
# Model with no String columns
|
||||||
|
class NoStringBase(DeclarativeBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class NoStringModel(NoStringBase):
|
||||||
|
__tablename__ = "no_strings"
|
||||||
|
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||||
|
count: Mapped[int] = mapped_column(Integer, default=0)
|
||||||
|
|
||||||
|
with pytest.raises(NoSearchableFieldsError) as exc_info:
|
||||||
|
build_search_filters(NoStringModel, "test")
|
||||||
|
|
||||||
|
assert exc_info.value.model is NoStringModel
|
||||||
|
assert "NoStringModel" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetSearchableFields:
|
||||||
|
"""Tests for auto-detection of searchable fields."""
|
||||||
|
|
||||||
|
def test_detects_string_columns(self):
|
||||||
|
"""Detects String columns on the model."""
|
||||||
|
fields = get_searchable_fields(User, include_relationships=False)
|
||||||
|
|
||||||
|
# Should include username and email (String), not id or is_active
|
||||||
|
field_names = [str(f) for f in fields]
|
||||||
|
assert any("username" in f for f in field_names)
|
||||||
|
assert any("email" in f for f in field_names)
|
||||||
|
assert not any("id" in f and "role_id" not in f for f in field_names)
|
||||||
|
assert not any("is_active" in f for f in field_names)
|
||||||
|
|
||||||
|
def test_detects_relationship_fields(self):
|
||||||
|
"""Detects String fields on related models."""
|
||||||
|
fields = get_searchable_fields(User, include_relationships=True)
|
||||||
|
|
||||||
|
# Should include (User.role, Role.name)
|
||||||
|
has_role_name = any(isinstance(f, tuple) and len(f) == 2 for f in fields)
|
||||||
|
assert has_role_name
|
||||||
|
|
||||||
|
def test_skips_collection_relationships(self):
|
||||||
|
"""Skips one-to-many relationships."""
|
||||||
|
fields = get_searchable_fields(Role, include_relationships=True)
|
||||||
|
|
||||||
|
# Role.users is a collection, should not be included
|
||||||
|
field_strs = [str(f) for f in fields]
|
||||||
|
assert not any("users" in f for f in field_strs)
|
||||||
102
tests/test_db.py
102
tests/test_db.py
@@ -1,5 +1,8 @@
|
|||||||
"""Tests for fastapi_toolsets.db module."""
|
"""Tests for fastapi_toolsets.db module."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
@@ -9,6 +12,7 @@ from fastapi_toolsets.db import (
|
|||||||
create_db_dependency,
|
create_db_dependency,
|
||||||
get_transaction,
|
get_transaction,
|
||||||
lock_tables,
|
lock_tables,
|
||||||
|
wait_for_row_change,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User
|
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User
|
||||||
@@ -241,3 +245,101 @@ class TestLockTables:
|
|||||||
|
|
||||||
result = await RoleCrud.first(db_session, [Role.name == "lock_rollback_role"])
|
result = await RoleCrud.first(db_session, [Role.name == "lock_rollback_role"])
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestWaitForRowChange:
|
||||||
|
"""Tests for wait_for_row_change polling function."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_detects_update(self, db_session: AsyncSession, engine):
|
||||||
|
"""Returns updated instance when a column value changes."""
|
||||||
|
role = Role(name="watch_role")
|
||||||
|
db_session.add(role)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
async def update_later():
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
async with factory() as other:
|
||||||
|
r = await other.get(Role, role.id)
|
||||||
|
assert r is not None
|
||||||
|
r.name = "updated_role"
|
||||||
|
await other.commit()
|
||||||
|
|
||||||
|
update_task = asyncio.create_task(update_later())
|
||||||
|
result = await wait_for_row_change(db_session, Role, role.id, interval=0.05)
|
||||||
|
await update_task
|
||||||
|
|
||||||
|
assert result.name == "updated_role"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_watches_specific_columns(self, db_session: AsyncSession, engine):
|
||||||
|
"""Only triggers on changes to specified columns."""
|
||||||
|
user = User(username="testuser", email="test@example.com")
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
async def update_later():
|
||||||
|
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
# First: change email (not watched) — should not trigger
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
async with factory() as other:
|
||||||
|
u = await other.get(User, user.id)
|
||||||
|
assert u is not None
|
||||||
|
u.email = "new@example.com"
|
||||||
|
await other.commit()
|
||||||
|
# Second: change username (watched) — should trigger
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
async with factory() as other:
|
||||||
|
u = await other.get(User, user.id)
|
||||||
|
assert u is not None
|
||||||
|
u.username = "newuser"
|
||||||
|
await other.commit()
|
||||||
|
|
||||||
|
update_task = asyncio.create_task(update_later())
|
||||||
|
result = await wait_for_row_change(
|
||||||
|
db_session, User, user.id, columns=["username"], interval=0.05
|
||||||
|
)
|
||||||
|
await update_task
|
||||||
|
|
||||||
|
assert result.username == "newuser"
|
||||||
|
assert result.email == "new@example.com"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_nonexistent_row_raises(self, db_session: AsyncSession):
|
||||||
|
"""Raises LookupError when the row does not exist."""
|
||||||
|
fake_id = uuid.uuid4()
|
||||||
|
with pytest.raises(LookupError, match="not found"):
|
||||||
|
await wait_for_row_change(db_session, Role, fake_id, interval=0.05)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_timeout_raises(self, db_session: AsyncSession):
|
||||||
|
"""Raises TimeoutError when no change is detected within timeout."""
|
||||||
|
role = Role(name="timeout_role")
|
||||||
|
db_session.add(role)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
with pytest.raises(TimeoutError):
|
||||||
|
await wait_for_row_change(
|
||||||
|
db_session, Role, role.id, interval=0.05, timeout=0.2
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_deleted_row_raises(self, db_session: AsyncSession, engine):
|
||||||
|
"""Raises LookupError when the row is deleted during polling."""
|
||||||
|
role = Role(name="delete_role")
|
||||||
|
db_session.add(role)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
async def delete_later():
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
async with factory() as other:
|
||||||
|
r = await other.get(Role, role.id)
|
||||||
|
await other.delete(r)
|
||||||
|
await other.commit()
|
||||||
|
|
||||||
|
delete_task = asyncio.create_task(delete_later())
|
||||||
|
with pytest.raises(LookupError):
|
||||||
|
await wait_for_row_change(db_session, Role, role.id, interval=0.05)
|
||||||
|
await delete_task
|
||||||
|
|||||||
186
tests/test_dependencies.py
Normal file
186
tests/test_dependencies.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
"""Tests for fastapi_toolsets.dependencies module."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import uuid
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.params import Depends
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from fastapi_toolsets.dependencies import BodyDependency, PathDependency
|
||||||
|
|
||||||
|
from .conftest import Role, RoleCreate, RoleCrud, User
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""Mock session dependency for testing."""
|
||||||
|
yield None
|
||||||
|
|
||||||
|
|
||||||
|
class TestPathDependency:
|
||||||
|
"""Tests for PathDependency factory."""
|
||||||
|
|
||||||
|
def test_returns_depends_instance(self):
|
||||||
|
"""PathDependency returns a Depends instance."""
|
||||||
|
dep = PathDependency(Role, Role.id, session_dep=mock_get_db)
|
||||||
|
assert isinstance(dep, Depends)
|
||||||
|
|
||||||
|
def test_signature_has_default_param_name(self):
|
||||||
|
"""PathDependency uses model_field as default param name."""
|
||||||
|
dep = cast(Any, PathDependency(Role, Role.id, session_dep=mock_get_db))
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
params = list(sig.parameters.keys())
|
||||||
|
|
||||||
|
assert "role_id" in params
|
||||||
|
assert "session" in params
|
||||||
|
|
||||||
|
def test_signature_has_correct_type_annotation(self):
|
||||||
|
"""PathDependency uses field's python type for annotation."""
|
||||||
|
dep = cast(Any, PathDependency(Role, Role.id, session_dep=mock_get_db))
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
assert sig.parameters["role_id"].annotation == uuid.UUID
|
||||||
|
assert sig.parameters["session"].annotation == AsyncSession
|
||||||
|
|
||||||
|
def test_signature_session_has_depends_default(self):
|
||||||
|
"""PathDependency session param has Depends as default."""
|
||||||
|
dep = cast(Any, PathDependency(Role, Role.id, session_dep=mock_get_db))
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
assert isinstance(sig.parameters["session"].default, Depends)
|
||||||
|
|
||||||
|
def test_custom_param_name_in_signature(self):
|
||||||
|
"""PathDependency uses custom param_name in signature."""
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
PathDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, param_name="role_uuid"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
params = list(sig.parameters.keys())
|
||||||
|
|
||||||
|
assert "role_uuid" in params
|
||||||
|
assert "id" not in params
|
||||||
|
|
||||||
|
def test_string_field_type(self):
|
||||||
|
"""PathDependency handles string field types."""
|
||||||
|
dep = cast(Any, PathDependency(User, User.username, session_dep=mock_get_db))
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
assert sig.parameters["user_username"].annotation is str
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_dependency_fetches_object(self, db_session):
|
||||||
|
"""PathDependency inner function fetches object from database."""
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="test_role"))
|
||||||
|
|
||||||
|
dep = cast(Any, PathDependency(Role, Role.id, session_dep=mock_get_db))
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
result = await func(session=db_session, role_id=role.id)
|
||||||
|
|
||||||
|
assert result.id == role.id
|
||||||
|
assert result.name == "test_role"
|
||||||
|
|
||||||
|
|
||||||
|
class TestBodyDependency:
|
||||||
|
"""Tests for BodyDependency factory."""
|
||||||
|
|
||||||
|
def test_returns_depends_instance(self):
|
||||||
|
"""BodyDependency returns a Depends instance."""
|
||||||
|
dep = BodyDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||||
|
)
|
||||||
|
assert isinstance(dep, Depends)
|
||||||
|
|
||||||
|
def test_signature_has_body_field_as_param(self):
|
||||||
|
"""BodyDependency uses body_field as param name."""
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
params = list(sig.parameters.keys())
|
||||||
|
|
||||||
|
assert "role_id" in params
|
||||||
|
assert "session" in params
|
||||||
|
|
||||||
|
def test_signature_has_correct_type_annotation(self):
|
||||||
|
"""BodyDependency uses field's python type for annotation."""
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
assert sig.parameters["role_id"].annotation == uuid.UUID
|
||||||
|
assert sig.parameters["session"].annotation == AsyncSession
|
||||||
|
|
||||||
|
def test_signature_session_has_depends_default(self):
|
||||||
|
"""BodyDependency session param has Depends as default."""
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
assert isinstance(sig.parameters["session"].default, Depends)
|
||||||
|
|
||||||
|
def test_different_body_field_name(self):
|
||||||
|
"""BodyDependency can use any body_field name."""
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
User, User.id, session_dep=mock_get_db, body_field="user_uuid"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
params = list(sig.parameters.keys())
|
||||||
|
|
||||||
|
assert "user_uuid" in params
|
||||||
|
assert "id" not in params
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_dependency_fetches_object(self, db_session):
|
||||||
|
"""BodyDependency inner function fetches object from database."""
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="body_test_role"))
|
||||||
|
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
result = await func(session=db_session, role_id=role.id)
|
||||||
|
|
||||||
|
assert result.id == role.id
|
||||||
|
assert result.name == "body_test_role"
|
||||||
@@ -108,6 +108,24 @@ class TestGenerateErrorResponses:
|
|||||||
assert example["status"] == "FAIL"
|
assert example["status"] == "FAIL"
|
||||||
assert example["error_code"] == "RES-404"
|
assert example["error_code"] == "RES-404"
|
||||||
assert example["message"] == "Not Found"
|
assert example["message"] == "Not Found"
|
||||||
|
assert example["data"] is None
|
||||||
|
|
||||||
|
def test_response_example_with_data(self):
|
||||||
|
"""Generated response includes data when set on ApiError."""
|
||||||
|
|
||||||
|
class ErrorWithData(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400,
|
||||||
|
msg="Bad Request",
|
||||||
|
desc="Invalid input.",
|
||||||
|
err_code="BAD-400",
|
||||||
|
data={"details": "some context"},
|
||||||
|
)
|
||||||
|
|
||||||
|
responses = generate_error_responses(ErrorWithData)
|
||||||
|
example = responses[400]["content"]["application/json"]["example"]
|
||||||
|
|
||||||
|
assert example["data"] == {"details": "some context"}
|
||||||
|
|
||||||
|
|
||||||
class TestInitExceptionsHandlers:
|
class TestInitExceptionsHandlers:
|
||||||
@@ -137,6 +155,59 @@ class TestInitExceptionsHandlers:
|
|||||||
assert data["error_code"] == "RES-404"
|
assert data["error_code"] == "RES-404"
|
||||||
assert data["message"] == "Not Found"
|
assert data["message"] == "Not Found"
|
||||||
|
|
||||||
|
def test_handles_api_exception_without_data(self):
|
||||||
|
"""ApiException without data returns null data field."""
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
@app.get("/error")
|
||||||
|
async def raise_error():
|
||||||
|
raise NotFoundError()
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/error")
|
||||||
|
|
||||||
|
assert response.status_code == 404
|
||||||
|
assert response.json()["data"] is None
|
||||||
|
|
||||||
|
def test_handles_api_exception_with_data(self):
|
||||||
|
"""ApiException with data returns the data payload."""
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
class CustomValidationError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=422,
|
||||||
|
msg="Validation Error",
|
||||||
|
desc="1 validation error(s) detected",
|
||||||
|
err_code="CUSTOM-422",
|
||||||
|
data={
|
||||||
|
"errors": [
|
||||||
|
{
|
||||||
|
"field": "email",
|
||||||
|
"message": "invalid format",
|
||||||
|
"type": "value_error",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/error")
|
||||||
|
async def raise_error():
|
||||||
|
raise CustomValidationError()
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/error")
|
||||||
|
|
||||||
|
assert response.status_code == 422
|
||||||
|
data = response.json()
|
||||||
|
assert data["data"] == {
|
||||||
|
"errors": [
|
||||||
|
{"field": "email", "message": "invalid format", "type": "value_error"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
assert data["error_code"] == "CUSTOM-422"
|
||||||
|
|
||||||
def test_handles_validation_error(self):
|
def test_handles_validation_error(self):
|
||||||
"""Handles validation errors with structured response."""
|
"""Handles validation errors with structured response."""
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""Tests for fastapi_toolsets.fixtures module."""
|
"""Tests for fastapi_toolsets.fixtures module."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
@@ -7,6 +9,7 @@ from fastapi_toolsets.fixtures import (
|
|||||||
Context,
|
Context,
|
||||||
FixtureRegistry,
|
FixtureRegistry,
|
||||||
LoadStrategy,
|
LoadStrategy,
|
||||||
|
get_obj_by_attr,
|
||||||
load_fixtures,
|
load_fixtures,
|
||||||
load_fixtures_by_context,
|
load_fixtures_by_context,
|
||||||
)
|
)
|
||||||
@@ -56,20 +59,22 @@ class TestFixtureRegistry:
|
|||||||
def test_register_with_decorator(self):
|
def test_register_with_decorator(self):
|
||||||
"""Register fixture with decorator."""
|
"""Register fixture with decorator."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=role_id, name="admin")]
|
||||||
|
|
||||||
assert "roles" in [f.name for f in registry.get_all()]
|
assert "roles" in [f.name for f in registry.get_all()]
|
||||||
|
|
||||||
def test_register_with_custom_name(self):
|
def test_register_with_custom_name(self):
|
||||||
"""Register fixture with custom name."""
|
"""Register fixture with custom name."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register(name="custom_roles")
|
@registry.register(name="custom_roles")
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=role_id, name="admin")]
|
||||||
|
|
||||||
fixture = registry.get("custom_roles")
|
fixture = registry.get("custom_roles")
|
||||||
assert fixture.name == "custom_roles"
|
assert fixture.name == "custom_roles"
|
||||||
@@ -77,14 +82,23 @@ class TestFixtureRegistry:
|
|||||||
def test_register_with_dependencies(self):
|
def test_register_with_dependencies(self):
|
||||||
"""Register fixture with dependencies."""
|
"""Register fixture with dependencies."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=role_id, name="admin")]
|
||||||
|
|
||||||
@registry.register(depends_on=["roles"])
|
@registry.register(depends_on=["roles"])
|
||||||
def users():
|
def users():
|
||||||
return [User(id=1, username="admin", email="admin@test.com", role_id=1)]
|
return [
|
||||||
|
User(
|
||||||
|
id=user_id,
|
||||||
|
username="admin",
|
||||||
|
email="admin@test.com",
|
||||||
|
role_id=role_id,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
fixture = registry.get("users")
|
fixture = registry.get("users")
|
||||||
assert fixture.depends_on == ["roles"]
|
assert fixture.depends_on == ["roles"]
|
||||||
@@ -92,10 +106,11 @@ class TestFixtureRegistry:
|
|||||||
def test_register_with_contexts(self):
|
def test_register_with_contexts(self):
|
||||||
"""Register fixture with contexts."""
|
"""Register fixture with contexts."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register(contexts=[Context.TESTING])
|
@registry.register(contexts=[Context.TESTING])
|
||||||
def test_data():
|
def test_data():
|
||||||
return [Role(id=100, name="test")]
|
return [Role(id=role_id, name="test")]
|
||||||
|
|
||||||
fixture = registry.get("test_data")
|
fixture = registry.get("test_data")
|
||||||
assert Context.TESTING.value in fixture.contexts
|
assert Context.TESTING.value in fixture.contexts
|
||||||
@@ -144,6 +159,178 @@ class TestFixtureRegistry:
|
|||||||
assert names == {"test_data"}
|
assert names == {"test_data"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestIncludeRegistry:
|
||||||
|
"""Tests for FixtureRegistry.include_registry method."""
|
||||||
|
|
||||||
|
def test_include_empty_registry(self):
|
||||||
|
"""Include an empty registry does nothing."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
other_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@main_registry.register
|
||||||
|
def roles():
|
||||||
|
return []
|
||||||
|
|
||||||
|
main_registry.include_registry(other_registry)
|
||||||
|
|
||||||
|
assert len(main_registry.get_all()) == 1
|
||||||
|
|
||||||
|
def test_include_registry_adds_fixtures(self):
|
||||||
|
"""Include registry adds all fixtures from the other registry."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
other_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@main_registry.register
|
||||||
|
def roles():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@other_registry.register
|
||||||
|
def users():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@other_registry.register
|
||||||
|
def posts():
|
||||||
|
return []
|
||||||
|
|
||||||
|
main_registry.include_registry(other_registry)
|
||||||
|
|
||||||
|
names = {f.name for f in main_registry.get_all()}
|
||||||
|
assert names == {"roles", "users", "posts"}
|
||||||
|
|
||||||
|
def test_include_registry_preserves_dependencies(self):
|
||||||
|
"""Include registry preserves fixture dependencies."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
other_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@main_registry.register
|
||||||
|
def roles():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@other_registry.register(depends_on=["roles"])
|
||||||
|
def users():
|
||||||
|
return []
|
||||||
|
|
||||||
|
main_registry.include_registry(other_registry)
|
||||||
|
|
||||||
|
fixture = main_registry.get("users")
|
||||||
|
assert fixture.depends_on == ["roles"]
|
||||||
|
|
||||||
|
def test_include_registry_preserves_contexts(self):
|
||||||
|
"""Include registry preserves fixture contexts."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
other_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@other_registry.register(contexts=[Context.TESTING, Context.DEVELOPMENT])
|
||||||
|
def test_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
main_registry.include_registry(other_registry)
|
||||||
|
|
||||||
|
fixture = main_registry.get("test_data")
|
||||||
|
assert Context.TESTING.value in fixture.contexts
|
||||||
|
assert Context.DEVELOPMENT.value in fixture.contexts
|
||||||
|
|
||||||
|
def test_include_registry_raises_on_duplicate(self):
|
||||||
|
"""Include registry raises ValueError on duplicate fixture names."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
other_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@main_registry.register(name="roles")
|
||||||
|
def roles_main():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@other_registry.register(name="roles")
|
||||||
|
def roles_other():
|
||||||
|
return []
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="already exists"):
|
||||||
|
main_registry.include_registry(other_registry)
|
||||||
|
|
||||||
|
def test_include_multiple_registries(self):
|
||||||
|
"""Include multiple registries sequentially."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
dev_registry = FixtureRegistry()
|
||||||
|
test_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@main_registry.register
|
||||||
|
def base():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@dev_registry.register
|
||||||
|
def dev_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@test_registry.register
|
||||||
|
def test_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
main_registry.include_registry(dev_registry)
|
||||||
|
main_registry.include_registry(test_registry)
|
||||||
|
|
||||||
|
names = {f.name for f in main_registry.get_all()}
|
||||||
|
assert names == {"base", "dev_data", "test_data"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestDefaultContexts:
|
||||||
|
"""Tests for FixtureRegistry default contexts."""
|
||||||
|
|
||||||
|
def test_default_contexts_applied_to_fixtures(self):
|
||||||
|
"""Default contexts are applied when no contexts specified."""
|
||||||
|
registry = FixtureRegistry(contexts=[Context.TESTING])
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def test_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
fixture = registry.get("test_data")
|
||||||
|
assert fixture.contexts == [Context.TESTING.value]
|
||||||
|
|
||||||
|
def test_explicit_contexts_override_default(self):
|
||||||
|
"""Explicit contexts override default contexts."""
|
||||||
|
registry = FixtureRegistry(contexts=[Context.TESTING])
|
||||||
|
|
||||||
|
@registry.register(contexts=[Context.PRODUCTION])
|
||||||
|
def prod_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
fixture = registry.get("prod_data")
|
||||||
|
assert fixture.contexts == [Context.PRODUCTION.value]
|
||||||
|
|
||||||
|
def test_no_default_contexts_uses_base(self):
|
||||||
|
"""Without default contexts, BASE is used."""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
fixture = registry.get("data")
|
||||||
|
assert fixture.contexts == [Context.BASE.value]
|
||||||
|
|
||||||
|
def test_multiple_default_contexts(self):
|
||||||
|
"""Multiple default contexts are applied."""
|
||||||
|
registry = FixtureRegistry(contexts=[Context.DEVELOPMENT, Context.TESTING])
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def dev_test_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
fixture = registry.get("dev_test_data")
|
||||||
|
assert Context.DEVELOPMENT.value in fixture.contexts
|
||||||
|
assert Context.TESTING.value in fixture.contexts
|
||||||
|
|
||||||
|
def test_default_contexts_with_string_values(self):
|
||||||
|
"""Default contexts work with string values."""
|
||||||
|
registry = FixtureRegistry(contexts=["custom_context"])
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def custom_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
fixture = registry.get("custom_data")
|
||||||
|
assert fixture.contexts == ["custom_context"]
|
||||||
|
|
||||||
|
|
||||||
class TestDependencyResolution:
|
class TestDependencyResolution:
|
||||||
"""Tests for fixture dependency resolution."""
|
"""Tests for fixture dependency resolution."""
|
||||||
|
|
||||||
@@ -243,12 +430,14 @@ class TestLoadFixtures:
|
|||||||
async def test_load_single_fixture(self, db_session: AsyncSession):
|
async def test_load_single_fixture(self, db_session: AsyncSession):
|
||||||
"""Load a single fixture."""
|
"""Load a single fixture."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id_1 = uuid.uuid4()
|
||||||
|
role_id_2 = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def roles():
|
def roles():
|
||||||
return [
|
return [
|
||||||
Role(id=1, name="admin"),
|
Role(id=role_id_1, name="admin"),
|
||||||
Role(id=2, name="user"),
|
Role(id=role_id_2, name="user"),
|
||||||
]
|
]
|
||||||
|
|
||||||
result = await load_fixtures(db_session, registry, "roles")
|
result = await load_fixtures(db_session, registry, "roles")
|
||||||
@@ -265,14 +454,23 @@ class TestLoadFixtures:
|
|||||||
async def test_load_with_dependencies(self, db_session: AsyncSession):
|
async def test_load_with_dependencies(self, db_session: AsyncSession):
|
||||||
"""Load fixtures with dependencies."""
|
"""Load fixtures with dependencies."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=role_id, name="admin")]
|
||||||
|
|
||||||
@registry.register(depends_on=["roles"])
|
@registry.register(depends_on=["roles"])
|
||||||
def users():
|
def users():
|
||||||
return [User(id=1, username="admin", email="admin@test.com", role_id=1)]
|
return [
|
||||||
|
User(
|
||||||
|
id=user_id,
|
||||||
|
username="admin",
|
||||||
|
email="admin@test.com",
|
||||||
|
role_id=role_id,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
result = await load_fixtures(db_session, registry, "users")
|
result = await load_fixtures(db_session, registry, "users")
|
||||||
|
|
||||||
@@ -288,10 +486,11 @@ class TestLoadFixtures:
|
|||||||
async def test_load_with_merge_strategy(self, db_session: AsyncSession):
|
async def test_load_with_merge_strategy(self, db_session: AsyncSession):
|
||||||
"""Load fixtures with MERGE strategy updates existing."""
|
"""Load fixtures with MERGE strategy updates existing."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=role_id, name="admin")]
|
||||||
|
|
||||||
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
|
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
|
||||||
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
|
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
|
||||||
@@ -305,10 +504,11 @@ class TestLoadFixtures:
|
|||||||
async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession):
|
async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession):
|
||||||
"""Load fixtures with SKIP_EXISTING strategy."""
|
"""Load fixtures with SKIP_EXISTING strategy."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="original")]
|
return [Role(id=role_id, name="original")]
|
||||||
|
|
||||||
await load_fixtures(
|
await load_fixtures(
|
||||||
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
|
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
|
||||||
@@ -316,7 +516,7 @@ class TestLoadFixtures:
|
|||||||
|
|
||||||
@registry.register(name="roles_updated")
|
@registry.register(name="roles_updated")
|
||||||
def roles_v2():
|
def roles_v2():
|
||||||
return [Role(id=1, name="updated")]
|
return [Role(id=role_id, name="updated")]
|
||||||
|
|
||||||
registry._fixtures["roles"] = registry._fixtures.pop("roles_updated")
|
registry._fixtures["roles"] = registry._fixtures.pop("roles_updated")
|
||||||
|
|
||||||
@@ -326,10 +526,77 @@ class TestLoadFixtures:
|
|||||||
|
|
||||||
from .conftest import RoleCrud
|
from .conftest import RoleCrud
|
||||||
|
|
||||||
role = await RoleCrud.first(db_session, [Role.id == 1])
|
role = await RoleCrud.first(db_session, [Role.id == role_id])
|
||||||
assert role is not None
|
assert role is not None
|
||||||
assert role.name == "original"
|
assert role.name == "original"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_load_with_insert_strategy(self, db_session: AsyncSession):
|
||||||
|
"""Load fixtures with INSERT strategy."""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
role_id_1 = uuid.uuid4()
|
||||||
|
role_id_2 = uuid.uuid4()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def roles():
|
||||||
|
return [
|
||||||
|
Role(id=role_id_1, name="admin"),
|
||||||
|
Role(id=role_id_2, name="user"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await load_fixtures(
|
||||||
|
db_session, registry, "roles", strategy=LoadStrategy.INSERT
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "roles" in result
|
||||||
|
assert len(result["roles"]) == 2
|
||||||
|
|
||||||
|
from .conftest import RoleCrud
|
||||||
|
|
||||||
|
count = await RoleCrud.count(db_session)
|
||||||
|
assert count == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_load_empty_fixture(self, db_session: AsyncSession):
|
||||||
|
"""Load a fixture that returns an empty list."""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def empty_roles():
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = await load_fixtures(db_session, registry, "empty_roles")
|
||||||
|
|
||||||
|
assert "empty_roles" in result
|
||||||
|
assert result["empty_roles"] == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_load_multiple_fixtures_without_dependencies(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""Load multiple independent fixtures."""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
role_id_1 = uuid.uuid4()
|
||||||
|
role_id_2 = uuid.uuid4()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def roles():
|
||||||
|
return [Role(id=role_id_1, name="admin")]
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def other_roles():
|
||||||
|
return [Role(id=role_id_2, name="user")]
|
||||||
|
|
||||||
|
result = await load_fixtures(db_session, registry, "roles", "other_roles")
|
||||||
|
|
||||||
|
assert "roles" in result
|
||||||
|
assert "other_roles" in result
|
||||||
|
|
||||||
|
from .conftest import RoleCrud
|
||||||
|
|
||||||
|
count = await RoleCrud.count(db_session)
|
||||||
|
assert count == 2
|
||||||
|
|
||||||
|
|
||||||
class TestLoadFixturesByContext:
|
class TestLoadFixturesByContext:
|
||||||
"""Tests for load_fixtures_by_context function."""
|
"""Tests for load_fixtures_by_context function."""
|
||||||
@@ -338,14 +605,16 @@ class TestLoadFixturesByContext:
|
|||||||
async def test_load_by_single_context(self, db_session: AsyncSession):
|
async def test_load_by_single_context(self, db_session: AsyncSession):
|
||||||
"""Load fixtures by single context."""
|
"""Load fixtures by single context."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
base_role_id = uuid.uuid4()
|
||||||
|
test_role_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register(contexts=[Context.BASE])
|
@registry.register(contexts=[Context.BASE])
|
||||||
def base_roles():
|
def base_roles():
|
||||||
return [Role(id=1, name="base_role")]
|
return [Role(id=base_role_id, name="base_role")]
|
||||||
|
|
||||||
@registry.register(contexts=[Context.TESTING])
|
@registry.register(contexts=[Context.TESTING])
|
||||||
def test_roles():
|
def test_roles():
|
||||||
return [Role(id=100, name="test_role")]
|
return [Role(id=test_role_id, name="test_role")]
|
||||||
|
|
||||||
await load_fixtures_by_context(db_session, registry, Context.BASE)
|
await load_fixtures_by_context(db_session, registry, Context.BASE)
|
||||||
|
|
||||||
@@ -354,7 +623,7 @@ class TestLoadFixturesByContext:
|
|||||||
count = await RoleCrud.count(db_session)
|
count = await RoleCrud.count(db_session)
|
||||||
assert count == 1
|
assert count == 1
|
||||||
|
|
||||||
role = await RoleCrud.first(db_session, [Role.id == 1])
|
role = await RoleCrud.first(db_session, [Role.id == base_role_id])
|
||||||
assert role is not None
|
assert role is not None
|
||||||
assert role.name == "base_role"
|
assert role.name == "base_role"
|
||||||
|
|
||||||
@@ -362,14 +631,16 @@ class TestLoadFixturesByContext:
|
|||||||
async def test_load_by_multiple_contexts(self, db_session: AsyncSession):
|
async def test_load_by_multiple_contexts(self, db_session: AsyncSession):
|
||||||
"""Load fixtures by multiple contexts."""
|
"""Load fixtures by multiple contexts."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
base_role_id = uuid.uuid4()
|
||||||
|
test_role_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register(contexts=[Context.BASE])
|
@registry.register(contexts=[Context.BASE])
|
||||||
def base_roles():
|
def base_roles():
|
||||||
return [Role(id=1, name="base_role")]
|
return [Role(id=base_role_id, name="base_role")]
|
||||||
|
|
||||||
@registry.register(contexts=[Context.TESTING])
|
@registry.register(contexts=[Context.TESTING])
|
||||||
def test_roles():
|
def test_roles():
|
||||||
return [Role(id=100, name="test_role")]
|
return [Role(id=test_role_id, name="test_role")]
|
||||||
|
|
||||||
await load_fixtures_by_context(
|
await load_fixtures_by_context(
|
||||||
db_session, registry, Context.BASE, Context.TESTING
|
db_session, registry, Context.BASE, Context.TESTING
|
||||||
@@ -384,14 +655,23 @@ class TestLoadFixturesByContext:
|
|||||||
async def test_load_context_with_dependencies(self, db_session: AsyncSession):
|
async def test_load_context_with_dependencies(self, db_session: AsyncSession):
|
||||||
"""Load context fixtures with cross-context dependencies."""
|
"""Load context fixtures with cross-context dependencies."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register(contexts=[Context.BASE])
|
@registry.register(contexts=[Context.BASE])
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=role_id, name="admin")]
|
||||||
|
|
||||||
@registry.register(depends_on=["roles"], contexts=[Context.TESTING])
|
@registry.register(depends_on=["roles"], contexts=[Context.TESTING])
|
||||||
def test_users():
|
def test_users():
|
||||||
return [User(id=1, username="tester", email="test@test.com", role_id=1)]
|
return [
|
||||||
|
User(
|
||||||
|
id=user_id,
|
||||||
|
username="tester",
|
||||||
|
email="test@test.com",
|
||||||
|
role_id=role_id,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
await load_fixtures_by_context(db_session, registry, Context.TESTING)
|
await load_fixtures_by_context(db_session, registry, Context.TESTING)
|
||||||
|
|
||||||
@@ -399,3 +679,79 @@ class TestLoadFixturesByContext:
|
|||||||
|
|
||||||
assert await RoleCrud.count(db_session) == 1
|
assert await RoleCrud.count(db_session) == 1
|
||||||
assert await UserCrud.count(db_session) == 1
|
assert await UserCrud.count(db_session) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetObjByAttr:
|
||||||
|
"""Tests for get_obj_by_attr helper function."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Set up test fixtures for each test."""
|
||||||
|
self.registry = FixtureRegistry()
|
||||||
|
self.role_id_1 = uuid.uuid4()
|
||||||
|
self.role_id_2 = uuid.uuid4()
|
||||||
|
self.role_id_3 = uuid.uuid4()
|
||||||
|
self.user_id_1 = uuid.uuid4()
|
||||||
|
self.user_id_2 = uuid.uuid4()
|
||||||
|
|
||||||
|
role_id_1 = self.role_id_1
|
||||||
|
role_id_2 = self.role_id_2
|
||||||
|
role_id_3 = self.role_id_3
|
||||||
|
user_id_1 = self.user_id_1
|
||||||
|
user_id_2 = self.user_id_2
|
||||||
|
|
||||||
|
@self.registry.register
|
||||||
|
def roles() -> list[Role]:
|
||||||
|
return [
|
||||||
|
Role(id=role_id_1, name="admin"),
|
||||||
|
Role(id=role_id_2, name="user"),
|
||||||
|
Role(id=role_id_3, name="moderator"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@self.registry.register(depends_on=["roles"])
|
||||||
|
def users() -> list[User]:
|
||||||
|
return [
|
||||||
|
User(
|
||||||
|
id=user_id_1,
|
||||||
|
username="alice",
|
||||||
|
email="alice@example.com",
|
||||||
|
role_id=role_id_1,
|
||||||
|
),
|
||||||
|
User(
|
||||||
|
id=user_id_2,
|
||||||
|
username="bob",
|
||||||
|
email="bob@example.com",
|
||||||
|
role_id=role_id_1,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
self.roles = roles
|
||||||
|
self.users = users
|
||||||
|
|
||||||
|
def test_get_by_id(self):
|
||||||
|
"""Get an object by its id attribute."""
|
||||||
|
role = get_obj_by_attr(self.roles, "id", self.role_id_1)
|
||||||
|
assert role.name == "admin"
|
||||||
|
|
||||||
|
def test_get_user_by_username(self):
|
||||||
|
"""Get a user by username."""
|
||||||
|
user = get_obj_by_attr(self.users, "username", "bob")
|
||||||
|
assert user.id == self.user_id_2
|
||||||
|
assert user.email == "bob@example.com"
|
||||||
|
|
||||||
|
def test_returns_first_match(self):
|
||||||
|
"""Returns the first matching object when multiple could match."""
|
||||||
|
user = get_obj_by_attr(self.users, "role_id", self.role_id_1)
|
||||||
|
assert user.username == "alice"
|
||||||
|
|
||||||
|
def test_no_match_raises_stop_iteration(self):
|
||||||
|
"""Raises StopIteration with contextual message when no object matches."""
|
||||||
|
with pytest.raises(
|
||||||
|
StopIteration,
|
||||||
|
match="No object with name=nonexistent found in fixture 'roles'",
|
||||||
|
):
|
||||||
|
get_obj_by_attr(self.roles, "name", "nonexistent")
|
||||||
|
|
||||||
|
def test_no_match_on_wrong_value_type(self):
|
||||||
|
"""Raises StopIteration when value type doesn't match."""
|
||||||
|
with pytest.raises(StopIteration):
|
||||||
|
get_obj_by_attr(self.roles, "id", "not-a-uuid")
|
||||||
|
|||||||
@@ -1,57 +0,0 @@
|
|||||||
"""Tests for fastapi_toolsets.fixtures.utils."""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from fastapi_toolsets.fixtures import FixtureRegistry
|
|
||||||
from fastapi_toolsets.fixtures.utils import get_obj_by_attr
|
|
||||||
|
|
||||||
from .conftest import Role, User
|
|
||||||
|
|
||||||
registry = FixtureRegistry()
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
|
||||||
def roles() -> list[Role]:
|
|
||||||
return [
|
|
||||||
Role(id=1, name="admin"),
|
|
||||||
Role(id=2, name="user"),
|
|
||||||
Role(id=3, name="moderator"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register(depends_on=["roles"])
|
|
||||||
def users() -> list[User]:
|
|
||||||
return [
|
|
||||||
User(id=1, username="alice", email="alice@example.com", role_id=1),
|
|
||||||
User(id=2, username="bob", email="bob@example.com", role_id=1),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetObjByAttr:
|
|
||||||
"""Tests for get_obj_by_attr."""
|
|
||||||
|
|
||||||
def test_get_by_id(self):
|
|
||||||
"""Get an object by its id attribute."""
|
|
||||||
role = get_obj_by_attr(roles, "id", 1)
|
|
||||||
assert role.name == "admin"
|
|
||||||
|
|
||||||
def test_get_user_by_username(self):
|
|
||||||
"""Get a user by username."""
|
|
||||||
user = get_obj_by_attr(users, "username", "bob")
|
|
||||||
assert user.id == 2
|
|
||||||
assert user.email == "bob@example.com"
|
|
||||||
|
|
||||||
def test_returns_first_match(self):
|
|
||||||
"""Returns the first matching object when multiple could match."""
|
|
||||||
user = get_obj_by_attr(users, "role_id", 1)
|
|
||||||
assert user.username == "alice"
|
|
||||||
|
|
||||||
def test_no_match_raises_stop_iteration(self):
|
|
||||||
"""Raises StopIteration when no object matches."""
|
|
||||||
with pytest.raises(StopIteration):
|
|
||||||
get_obj_by_attr(roles, "name", "nonexistent")
|
|
||||||
|
|
||||||
def test_no_match_on_wrong_value_type(self):
|
|
||||||
"""Raises StopIteration when value type doesn't match."""
|
|
||||||
with pytest.raises(StopIteration):
|
|
||||||
get_obj_by_attr(roles, "id", "1")
|
|
||||||
229
tests/test_imports.py
Normal file
229
tests/test_imports.py
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
"""Tests for optional dependency import guards."""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from fastapi_toolsets._imports import require_extra
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequireExtra:
|
||||||
|
"""Tests for the require_extra helper."""
|
||||||
|
|
||||||
|
def test_raises_import_error(self):
|
||||||
|
"""require_extra raises ImportError."""
|
||||||
|
with pytest.raises(ImportError):
|
||||||
|
require_extra(package="some_pkg", extra="some_extra")
|
||||||
|
|
||||||
|
def test_error_message_contains_package_name(self):
|
||||||
|
"""Error message mentions the missing package."""
|
||||||
|
with pytest.raises(ImportError, match="'prometheus_client'"):
|
||||||
|
require_extra(package="prometheus_client", extra="metrics")
|
||||||
|
|
||||||
|
def test_error_message_contains_install_instruction(self):
|
||||||
|
"""Error message contains the pip install command."""
|
||||||
|
with pytest.raises(
|
||||||
|
ImportError, match=r"pip install fastapi-toolsets\[metrics\]"
|
||||||
|
):
|
||||||
|
require_extra(package="prometheus_client", extra="metrics")
|
||||||
|
|
||||||
|
|
||||||
|
def _reload_without_package(module_path: str, blocked_packages: list[str]):
|
||||||
|
"""Reload a module while blocking specific package imports.
|
||||||
|
|
||||||
|
Removes the target module and its parents from sys.modules so they
|
||||||
|
get re-imported, and patches builtins.__import__ to raise ImportError
|
||||||
|
for *blocked_packages*.
|
||||||
|
"""
|
||||||
|
# Remove cached modules so they get re-imported
|
||||||
|
to_remove = [
|
||||||
|
key
|
||||||
|
for key in sys.modules
|
||||||
|
if key == module_path or key.startswith(module_path + ".")
|
||||||
|
]
|
||||||
|
saved = {}
|
||||||
|
for key in to_remove:
|
||||||
|
saved[key] = sys.modules.pop(key)
|
||||||
|
|
||||||
|
# Also remove parent package to force re-execution of __init__.py
|
||||||
|
parts = module_path.rsplit(".", 1)
|
||||||
|
if len(parts) == 2:
|
||||||
|
parent = parts[0]
|
||||||
|
parent_keys = [
|
||||||
|
key for key in sys.modules if key == parent or key.startswith(parent + ".")
|
||||||
|
]
|
||||||
|
for key in parent_keys:
|
||||||
|
if key not in saved:
|
||||||
|
saved[key] = sys.modules.pop(key)
|
||||||
|
|
||||||
|
original_import = (
|
||||||
|
__builtins__.__import__ if hasattr(__builtins__, "__import__") else __import__
|
||||||
|
)
|
||||||
|
|
||||||
|
def blocking_import(name, *args, **kwargs):
|
||||||
|
for blocked in blocked_packages:
|
||||||
|
if name == blocked or name.startswith(blocked + "."):
|
||||||
|
raise ImportError(f"Mocked: No module named '{name}'")
|
||||||
|
return original_import(name, *args, **kwargs)
|
||||||
|
|
||||||
|
return saved, blocking_import
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricsImportGuard:
|
||||||
|
"""Tests for metrics module import guard when prometheus_client is missing."""
|
||||||
|
|
||||||
|
def test_registry_imports_without_prometheus(self):
|
||||||
|
"""Metric and MetricsRegistry are importable without prometheus_client."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.metrics", ["prometheus_client"]
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
mod = importlib.import_module("fastapi_toolsets.metrics")
|
||||||
|
# Registry types should be available (they're stdlib-only)
|
||||||
|
assert hasattr(mod, "Metric")
|
||||||
|
assert hasattr(mod, "MetricsRegistry")
|
||||||
|
finally:
|
||||||
|
# Restore original modules
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.metrics"):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_init_metrics_stub_raises_without_prometheus(self):
|
||||||
|
"""init_metrics raises ImportError when prometheus_client is missing."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.metrics", ["prometheus_client"]
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
mod = importlib.import_module("fastapi_toolsets.metrics")
|
||||||
|
with pytest.raises(ImportError, match="prometheus_client"):
|
||||||
|
mod.init_metrics(None, None) # type: ignore[arg-type]
|
||||||
|
finally:
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.metrics"):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_init_metrics_works_with_prometheus(self):
|
||||||
|
"""init_metrics is the real function when prometheus_client is available."""
|
||||||
|
from fastapi_toolsets.metrics import init_metrics
|
||||||
|
|
||||||
|
# Should be the real function, not a stub
|
||||||
|
assert init_metrics.__module__ == "fastapi_toolsets.metrics.handler"
|
||||||
|
|
||||||
|
|
||||||
|
class TestPytestImportGuard:
|
||||||
|
"""Tests for pytest module import guard when dependencies are missing."""
|
||||||
|
|
||||||
|
def test_import_raises_without_pytest_package(self):
|
||||||
|
"""Importing fastapi_toolsets.pytest raises when pytest is missing."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.pytest", ["pytest"]
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
with pytest.raises(ImportError, match="pytest"):
|
||||||
|
importlib.import_module("fastapi_toolsets.pytest")
|
||||||
|
finally:
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.pytest"):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_import_raises_without_httpx(self):
|
||||||
|
"""Importing fastapi_toolsets.pytest raises when httpx is missing."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.pytest", ["httpx"]
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
with pytest.raises(ImportError, match="httpx"):
|
||||||
|
importlib.import_module("fastapi_toolsets.pytest")
|
||||||
|
finally:
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.pytest"):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_all_exports_available_with_deps(self):
|
||||||
|
"""All expected exports are available when deps are installed."""
|
||||||
|
from fastapi_toolsets.pytest import (
|
||||||
|
cleanup_tables,
|
||||||
|
create_async_client,
|
||||||
|
create_db_session,
|
||||||
|
create_worker_database,
|
||||||
|
register_fixtures,
|
||||||
|
worker_database_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert callable(register_fixtures)
|
||||||
|
assert callable(create_async_client)
|
||||||
|
assert callable(create_db_session)
|
||||||
|
assert callable(create_worker_database)
|
||||||
|
assert callable(worker_database_url)
|
||||||
|
assert callable(cleanup_tables)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCliImportGuard:
|
||||||
|
"""Tests for CLI module import guard when typer is missing."""
|
||||||
|
|
||||||
|
def test_import_raises_without_typer(self):
|
||||||
|
"""Importing cli.app raises when typer is missing."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.cli.app", ["typer"]
|
||||||
|
)
|
||||||
|
# Also remove cli.config since it imports typer too
|
||||||
|
config_keys = [
|
||||||
|
k for k in sys.modules if k.startswith("fastapi_toolsets.cli.config")
|
||||||
|
]
|
||||||
|
for key in config_keys:
|
||||||
|
if key not in saved:
|
||||||
|
saved[key] = sys.modules.pop(key)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
with pytest.raises(ImportError, match="typer"):
|
||||||
|
importlib.import_module("fastapi_toolsets.cli.app")
|
||||||
|
finally:
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.cli.app") or key.startswith(
|
||||||
|
"fastapi_toolsets.cli.config"
|
||||||
|
):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_error_message_suggests_cli_extra(self):
|
||||||
|
"""Error message suggests installing the cli extra."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.cli.app", ["typer"]
|
||||||
|
)
|
||||||
|
config_keys = [
|
||||||
|
k for k in sys.modules if k.startswith("fastapi_toolsets.cli.config")
|
||||||
|
]
|
||||||
|
for key in config_keys:
|
||||||
|
if key not in saved:
|
||||||
|
saved[key] = sys.modules.pop(key)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
with pytest.raises(
|
||||||
|
ImportError, match=r"pip install fastapi-toolsets\[cli\]"
|
||||||
|
):
|
||||||
|
importlib.import_module("fastapi_toolsets.cli.app")
|
||||||
|
finally:
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.cli.app") or key.startswith(
|
||||||
|
"fastapi_toolsets.cli.config"
|
||||||
|
):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_async_command_imports_without_typer(self):
|
||||||
|
"""async_command is importable without typer (stdlib only)."""
|
||||||
|
from fastapi_toolsets.cli import async_command
|
||||||
|
|
||||||
|
assert callable(async_command)
|
||||||
118
tests/test_logger.py
Normal file
118
tests/test_logger.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from fastapi_toolsets.logger import (
|
||||||
|
DEFAULT_FORMAT,
|
||||||
|
UVICORN_LOGGERS,
|
||||||
|
configure_logging,
|
||||||
|
get_logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _reset_loggers():
|
||||||
|
"""Reset the root and uvicorn loggers after each test."""
|
||||||
|
yield
|
||||||
|
root = logging.getLogger()
|
||||||
|
root.handlers.clear()
|
||||||
|
root.setLevel(logging.WARNING)
|
||||||
|
for name in UVICORN_LOGGERS:
|
||||||
|
uv = logging.getLogger(name)
|
||||||
|
uv.handlers.clear()
|
||||||
|
uv.setLevel(logging.NOTSET)
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfigureLogging:
|
||||||
|
def test_sets_up_handler_and_format(self):
|
||||||
|
logger = configure_logging()
|
||||||
|
|
||||||
|
assert len(logger.handlers) == 1
|
||||||
|
handler = logger.handlers[0]
|
||||||
|
assert isinstance(handler, logging.StreamHandler)
|
||||||
|
assert handler.stream is sys.stdout
|
||||||
|
assert handler.formatter is not None
|
||||||
|
assert handler.formatter._fmt == DEFAULT_FORMAT
|
||||||
|
|
||||||
|
def test_default_level_is_info(self):
|
||||||
|
logger = configure_logging()
|
||||||
|
|
||||||
|
assert logger.level == logging.INFO
|
||||||
|
|
||||||
|
def test_custom_level_string(self):
|
||||||
|
logger = configure_logging(level="DEBUG")
|
||||||
|
|
||||||
|
assert logger.level == logging.DEBUG
|
||||||
|
|
||||||
|
def test_custom_level_int(self):
|
||||||
|
logger = configure_logging(level=logging.WARNING)
|
||||||
|
|
||||||
|
assert logger.level == logging.WARNING
|
||||||
|
|
||||||
|
def test_custom_format(self):
|
||||||
|
custom_fmt = "%(levelname)s: %(message)s"
|
||||||
|
logger = configure_logging(fmt=custom_fmt)
|
||||||
|
|
||||||
|
handler = logger.handlers[0]
|
||||||
|
assert handler.formatter is not None
|
||||||
|
assert handler.formatter._fmt == custom_fmt
|
||||||
|
|
||||||
|
def test_named_logger(self):
|
||||||
|
logger = configure_logging(logger_name="myapp")
|
||||||
|
|
||||||
|
assert logger.name == "myapp"
|
||||||
|
assert len(logger.handlers) == 1
|
||||||
|
|
||||||
|
def test_default_configures_root_logger(self):
|
||||||
|
logger = configure_logging()
|
||||||
|
|
||||||
|
assert logger is logging.getLogger()
|
||||||
|
|
||||||
|
def test_idempotent_no_duplicate_handlers(self):
|
||||||
|
configure_logging()
|
||||||
|
configure_logging()
|
||||||
|
logger = configure_logging()
|
||||||
|
|
||||||
|
assert len(logger.handlers) == 1
|
||||||
|
|
||||||
|
def test_configures_uvicorn_loggers(self):
|
||||||
|
configure_logging(level="DEBUG")
|
||||||
|
|
||||||
|
for name in UVICORN_LOGGERS:
|
||||||
|
uv_logger = logging.getLogger(name)
|
||||||
|
assert len(uv_logger.handlers) == 1
|
||||||
|
assert uv_logger.level == logging.DEBUG
|
||||||
|
handler = uv_logger.handlers[0]
|
||||||
|
assert handler.formatter is not None
|
||||||
|
assert handler.formatter._fmt == DEFAULT_FORMAT
|
||||||
|
|
||||||
|
def test_returns_configured_logger(self):
|
||||||
|
logger = configure_logging(logger_name="test.return")
|
||||||
|
|
||||||
|
assert isinstance(logger, logging.Logger)
|
||||||
|
assert logger.name == "test.return"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetLogger:
|
||||||
|
def test_returns_named_logger(self):
|
||||||
|
logger = get_logger("myapp.services")
|
||||||
|
|
||||||
|
assert isinstance(logger, logging.Logger)
|
||||||
|
assert logger.name == "myapp.services"
|
||||||
|
|
||||||
|
def test_returns_root_logger_when_none(self):
|
||||||
|
logger = get_logger(None)
|
||||||
|
|
||||||
|
assert logger is logging.getLogger()
|
||||||
|
|
||||||
|
def test_defaults_to_caller_module_name(self):
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
assert logger.name == __name__
|
||||||
|
|
||||||
|
def test_same_name_returns_same_logger(self):
|
||||||
|
a = get_logger("myapp")
|
||||||
|
b = get_logger("myapp")
|
||||||
|
|
||||||
|
assert a is b
|
||||||
519
tests/test_metrics.py
Normal file
519
tests/test_metrics.py
Normal file
@@ -0,0 +1,519 @@
|
|||||||
|
"""Tests for fastapi_toolsets.metrics module."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from prometheus_client import REGISTRY, CollectorRegistry, Counter, Gauge
|
||||||
|
|
||||||
|
from fastapi_toolsets.metrics import Metric, MetricsRegistry, init_metrics
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _clean_prometheus_registry():
|
||||||
|
"""Unregister test collectors from the global registry after each test."""
|
||||||
|
yield
|
||||||
|
collectors = list(REGISTRY._names_to_collectors.values())
|
||||||
|
for collector in collectors:
|
||||||
|
try:
|
||||||
|
REGISTRY.unregister(collector)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetric:
|
||||||
|
"""Tests for Metric dataclass."""
|
||||||
|
|
||||||
|
def test_default_collect_is_false(self):
|
||||||
|
"""Default collect is False (provider mode)."""
|
||||||
|
definition = Metric(name="test", func=lambda: None)
|
||||||
|
assert definition.collect is False
|
||||||
|
|
||||||
|
def test_collect_true(self):
|
||||||
|
"""Collect can be set to True (collector mode)."""
|
||||||
|
definition = Metric(name="test", func=lambda: None, collect=True)
|
||||||
|
assert definition.collect is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricsRegistry:
|
||||||
|
"""Tests for MetricsRegistry class."""
|
||||||
|
|
||||||
|
def test_register_with_decorator(self):
|
||||||
|
"""Register metric with bare decorator."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def my_counter():
|
||||||
|
return Counter("test_counter", "A test counter")
|
||||||
|
|
||||||
|
names = [m.name for m in registry.get_all()]
|
||||||
|
assert "my_counter" in names
|
||||||
|
|
||||||
|
def test_register_with_custom_name(self):
|
||||||
|
"""Register metric with custom name."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register(name="custom_name")
|
||||||
|
def my_counter():
|
||||||
|
return Counter("test_counter_2", "A test counter")
|
||||||
|
|
||||||
|
definition = registry.get_all()[0]
|
||||||
|
assert definition.name == "custom_name"
|
||||||
|
|
||||||
|
def test_register_as_collector(self):
|
||||||
|
"""Register metric with collect=True."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def collect_something():
|
||||||
|
pass
|
||||||
|
|
||||||
|
definition = registry.get_all()[0]
|
||||||
|
assert definition.collect is True
|
||||||
|
|
||||||
|
def test_register_preserves_function(self):
|
||||||
|
"""Decorator returns the original function unchanged."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
def my_func():
|
||||||
|
return "original"
|
||||||
|
|
||||||
|
result = registry.register(my_func)
|
||||||
|
assert result is my_func
|
||||||
|
assert result() == "original"
|
||||||
|
|
||||||
|
def test_register_parameterized_preserves_function(self):
|
||||||
|
"""Parameterized decorator returns the original function unchanged."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
def my_func():
|
||||||
|
return "original"
|
||||||
|
|
||||||
|
result = registry.register(name="custom")(my_func)
|
||||||
|
assert result is my_func
|
||||||
|
assert result() == "original"
|
||||||
|
|
||||||
|
def test_get_all(self):
|
||||||
|
"""Get all registered metrics."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def metric_a():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def metric_b():
|
||||||
|
pass
|
||||||
|
|
||||||
|
names = {m.name for m in registry.get_all()}
|
||||||
|
assert names == {"metric_a", "metric_b"}
|
||||||
|
|
||||||
|
def test_get_providers(self):
|
||||||
|
"""Get only provider metrics (collect=False)."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def provider():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def collector():
|
||||||
|
pass
|
||||||
|
|
||||||
|
providers = registry.get_providers()
|
||||||
|
assert len(providers) == 1
|
||||||
|
assert providers[0].name == "provider"
|
||||||
|
|
||||||
|
def test_get_collectors(self):
|
||||||
|
"""Get only collector metrics (collect=True)."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def provider():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def collector():
|
||||||
|
pass
|
||||||
|
|
||||||
|
collectors = registry.get_collectors()
|
||||||
|
assert len(collectors) == 1
|
||||||
|
assert collectors[0].name == "collector"
|
||||||
|
|
||||||
|
def test_register_overwrites_same_name(self):
|
||||||
|
"""Registering with the same name overwrites the previous entry."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register(name="metric")
|
||||||
|
def first():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@registry.register(name="metric")
|
||||||
|
def second():
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert len(registry.get_all()) == 1
|
||||||
|
assert registry.get_all()[0].func is second
|
||||||
|
|
||||||
|
|
||||||
|
class TestIncludeRegistry:
|
||||||
|
"""Tests for MetricsRegistry.include_registry method."""
|
||||||
|
|
||||||
|
def test_include_empty_registry(self):
|
||||||
|
"""Include an empty registry does nothing."""
|
||||||
|
main = MetricsRegistry()
|
||||||
|
other = MetricsRegistry()
|
||||||
|
|
||||||
|
@main.register
|
||||||
|
def metric_a():
|
||||||
|
pass
|
||||||
|
|
||||||
|
main.include_registry(other)
|
||||||
|
assert len(main.get_all()) == 1
|
||||||
|
|
||||||
|
def test_include_registry_adds_metrics(self):
|
||||||
|
"""Include registry adds all metrics from the other registry."""
|
||||||
|
main = MetricsRegistry()
|
||||||
|
other = MetricsRegistry()
|
||||||
|
|
||||||
|
@main.register
|
||||||
|
def metric_a():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@other.register
|
||||||
|
def metric_b():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@other.register
|
||||||
|
def metric_c():
|
||||||
|
pass
|
||||||
|
|
||||||
|
main.include_registry(other)
|
||||||
|
names = {m.name for m in main.get_all()}
|
||||||
|
assert names == {"metric_a", "metric_b", "metric_c"}
|
||||||
|
|
||||||
|
def test_include_registry_preserves_collect_flag(self):
|
||||||
|
"""Include registry preserves the collect flag."""
|
||||||
|
main = MetricsRegistry()
|
||||||
|
other = MetricsRegistry()
|
||||||
|
|
||||||
|
@other.register(collect=True)
|
||||||
|
def collector():
|
||||||
|
pass
|
||||||
|
|
||||||
|
main.include_registry(other)
|
||||||
|
assert main.get_all()[0].collect is True
|
||||||
|
|
||||||
|
def test_include_registry_raises_on_duplicate(self):
|
||||||
|
"""Include registry raises ValueError on duplicate metric names."""
|
||||||
|
main = MetricsRegistry()
|
||||||
|
other = MetricsRegistry()
|
||||||
|
|
||||||
|
@main.register(name="metric")
|
||||||
|
def metric_main():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@other.register(name="metric")
|
||||||
|
def metric_other():
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="already exists"):
|
||||||
|
main.include_registry(other)
|
||||||
|
|
||||||
|
def test_include_multiple_registries(self):
|
||||||
|
"""Include multiple registries sequentially."""
|
||||||
|
main = MetricsRegistry()
|
||||||
|
sub1 = MetricsRegistry()
|
||||||
|
sub2 = MetricsRegistry()
|
||||||
|
|
||||||
|
@main.register
|
||||||
|
def base():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@sub1.register
|
||||||
|
def sub1_metric():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@sub2.register
|
||||||
|
def sub2_metric():
|
||||||
|
pass
|
||||||
|
|
||||||
|
main.include_registry(sub1)
|
||||||
|
main.include_registry(sub2)
|
||||||
|
|
||||||
|
names = {m.name for m in main.get_all()}
|
||||||
|
assert names == {"base", "sub1_metric", "sub2_metric"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestInitMetrics:
|
||||||
|
"""Tests for init_metrics function."""
|
||||||
|
|
||||||
|
def test_returns_app(self):
|
||||||
|
"""Returns the FastAPI app."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
result = init_metrics(app, registry)
|
||||||
|
assert result is app
|
||||||
|
|
||||||
|
def test_metrics_endpoint_responds(self):
|
||||||
|
"""The /metrics endpoint returns 200."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/metrics")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_metrics_endpoint_content_type(self):
|
||||||
|
"""The /metrics endpoint returns prometheus content type."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/metrics")
|
||||||
|
|
||||||
|
assert "text/plain" in response.headers["content-type"]
|
||||||
|
|
||||||
|
def test_custom_path(self):
|
||||||
|
"""Custom path is used for the metrics endpoint."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
init_metrics(app, registry, path="/custom-metrics")
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
assert client.get("/custom-metrics").status_code == 200
|
||||||
|
assert client.get("/metrics").status_code == 404
|
||||||
|
|
||||||
|
def test_providers_called_at_init(self):
|
||||||
|
"""Provider functions are called once at init time."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
mock = MagicMock()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def my_provider():
|
||||||
|
mock()
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
mock.assert_called_once()
|
||||||
|
|
||||||
|
def test_collectors_called_on_scrape(self):
|
||||||
|
"""Collector functions are called on each scrape."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
mock = MagicMock()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def my_collector():
|
||||||
|
mock()
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
client.get("/metrics")
|
||||||
|
client.get("/metrics")
|
||||||
|
|
||||||
|
assert mock.call_count == 2
|
||||||
|
|
||||||
|
def test_collectors_not_called_at_init(self):
|
||||||
|
"""Collector functions are not called at init time."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
mock = MagicMock()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def my_collector():
|
||||||
|
mock()
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
mock.assert_not_called()
|
||||||
|
|
||||||
|
def test_async_collectors_called_on_scrape(self):
|
||||||
|
"""Async collector functions are awaited on each scrape."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
mock = AsyncMock()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
async def my_async_collector():
|
||||||
|
await mock()
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
client.get("/metrics")
|
||||||
|
client.get("/metrics")
|
||||||
|
|
||||||
|
assert mock.call_count == 2
|
||||||
|
|
||||||
|
def test_mixed_sync_and_async_collectors(self):
|
||||||
|
"""Both sync and async collectors are called on scrape."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
sync_mock = MagicMock()
|
||||||
|
async_mock = AsyncMock()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def sync_collector():
|
||||||
|
sync_mock()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
async def async_collector():
|
||||||
|
await async_mock()
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
client.get("/metrics")
|
||||||
|
|
||||||
|
sync_mock.assert_called_once()
|
||||||
|
async_mock.assert_called_once()
|
||||||
|
|
||||||
|
def test_registered_metrics_appear_in_output(self):
|
||||||
|
"""Metrics created by providers appear in /metrics output."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def my_gauge():
|
||||||
|
g = Gauge("test_gauge_value", "A test gauge")
|
||||||
|
g.set(42)
|
||||||
|
return g
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/metrics")
|
||||||
|
|
||||||
|
assert b"test_gauge_value" in response.content
|
||||||
|
assert b"42.0" in response.content
|
||||||
|
|
||||||
|
def test_endpoint_not_in_openapi_schema(self):
|
||||||
|
"""The /metrics endpoint is not included in the OpenAPI schema."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
schema = app.openapi()
|
||||||
|
assert "/metrics" not in schema.get("paths", {})
|
||||||
|
|
||||||
|
|
||||||
|
class TestMultiProcessMode:
|
||||||
|
"""Tests for multi-process Prometheus mode."""
|
||||||
|
|
||||||
|
def test_multiprocess_with_env_var(self):
|
||||||
|
"""Multi-process mode works when PROMETHEUS_MULTIPROC_DIR is set."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
os.environ["PROMETHEUS_MULTIPROC_DIR"] = tmpdir
|
||||||
|
try:
|
||||||
|
# Use a separate registry to avoid conflicts with default
|
||||||
|
prom_registry = CollectorRegistry()
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def mp_counter():
|
||||||
|
return Counter(
|
||||||
|
"mp_test_counter",
|
||||||
|
"A multiprocess counter",
|
||||||
|
registry=prom_registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/metrics")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
finally:
|
||||||
|
del os.environ["PROMETHEUS_MULTIPROC_DIR"]
|
||||||
|
|
||||||
|
def test_single_process_without_env_var(self):
|
||||||
|
"""Single-process mode when PROMETHEUS_MULTIPROC_DIR is not set."""
|
||||||
|
os.environ.pop("PROMETHEUS_MULTIPROC_DIR", None)
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def sp_gauge():
|
||||||
|
g = Gauge("sp_test_gauge", "A single-process gauge")
|
||||||
|
g.set(99)
|
||||||
|
return g
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/metrics")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert b"sp_test_gauge" in response.content
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricsIntegration:
|
||||||
|
"""Integration tests for the metrics module."""
|
||||||
|
|
||||||
|
def test_full_workflow(self):
|
||||||
|
"""Full workflow: registry, providers, collectors, endpoint."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
call_count = {"value": 0}
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def request_counter():
|
||||||
|
return Counter(
|
||||||
|
"integration_requests_total",
|
||||||
|
"Total requests",
|
||||||
|
["method"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def collect_uptime():
|
||||||
|
call_count["value"] += 1
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
response = client.get("/metrics")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert b"integration_requests_total" in response.content
|
||||||
|
assert call_count["value"] == 1
|
||||||
|
|
||||||
|
response = client.get("/metrics")
|
||||||
|
assert call_count["value"] == 2
|
||||||
|
|
||||||
|
def test_multiple_registries_merged(self):
|
||||||
|
"""Multiple registries can be merged and used together."""
|
||||||
|
app = FastAPI()
|
||||||
|
main = MetricsRegistry()
|
||||||
|
sub = MetricsRegistry()
|
||||||
|
|
||||||
|
@main.register
|
||||||
|
def main_gauge():
|
||||||
|
g = Gauge("main_gauge_val", "Main gauge")
|
||||||
|
g.set(1)
|
||||||
|
return g
|
||||||
|
|
||||||
|
@sub.register
|
||||||
|
def sub_gauge():
|
||||||
|
g = Gauge("sub_gauge_val", "Sub gauge")
|
||||||
|
g.set(2)
|
||||||
|
return g
|
||||||
|
|
||||||
|
main.include_registry(sub)
|
||||||
|
init_metrics(app, main)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/metrics")
|
||||||
|
|
||||||
|
assert b"main_gauge_val" in response.content
|
||||||
|
assert b"sub_gauge_val" in response.content
|
||||||
535
tests/test_pytest.py
Normal file
535
tests/test_pytest.py
Normal file
@@ -0,0 +1,535 @@
|
|||||||
|
"""Tests for fastapi_toolsets.pytest module."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import Depends, FastAPI
|
||||||
|
from httpx import AsyncClient
|
||||||
|
from sqlalchemy import select, text
|
||||||
|
from sqlalchemy.engine import make_url
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, selectinload
|
||||||
|
|
||||||
|
from fastapi_toolsets.fixtures import Context, FixtureRegistry
|
||||||
|
from fastapi_toolsets.pytest import (
|
||||||
|
cleanup_tables,
|
||||||
|
create_async_client,
|
||||||
|
create_db_session,
|
||||||
|
create_worker_database,
|
||||||
|
register_fixtures,
|
||||||
|
worker_database_url,
|
||||||
|
)
|
||||||
|
from fastapi_toolsets.pytest.utils import _get_xdist_worker
|
||||||
|
|
||||||
|
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
|
||||||
|
|
||||||
|
test_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
# Fixed UUIDs for test fixtures to allow consistent assertions
|
||||||
|
ROLE_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000001000")
|
||||||
|
ROLE_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000001001")
|
||||||
|
USER_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000002000")
|
||||||
|
USER_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000002001")
|
||||||
|
USER_EXTRA_ID = uuid.UUID("00000000-0000-0000-0000-000000002002")
|
||||||
|
|
||||||
|
|
||||||
|
@test_registry.register(contexts=[Context.BASE])
|
||||||
|
def roles() -> list[Role]:
|
||||||
|
return [
|
||||||
|
Role(id=ROLE_ADMIN_ID, name="plugin_admin"),
|
||||||
|
Role(id=ROLE_USER_ID, name="plugin_user"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@test_registry.register(depends_on=["roles"], contexts=[Context.BASE])
|
||||||
|
def users() -> list[User]:
|
||||||
|
return [
|
||||||
|
User(
|
||||||
|
id=USER_ADMIN_ID,
|
||||||
|
username="plugin_admin",
|
||||||
|
email="padmin@test.com",
|
||||||
|
role_id=ROLE_ADMIN_ID,
|
||||||
|
),
|
||||||
|
User(
|
||||||
|
id=USER_USER_ID,
|
||||||
|
username="plugin_user",
|
||||||
|
email="puser@test.com",
|
||||||
|
role_id=ROLE_USER_ID,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@test_registry.register(depends_on=["users"], contexts=[Context.TESTING])
|
||||||
|
def extra_users() -> list[User]:
|
||||||
|
return [
|
||||||
|
User(
|
||||||
|
id=USER_EXTRA_ID,
|
||||||
|
username="plugin_extra",
|
||||||
|
email="pextra@test.com",
|
||||||
|
role_id=ROLE_USER_ID,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
register_fixtures(test_registry, globals())
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegisterFixtures:
|
||||||
|
"""Tests for register_fixtures function."""
|
||||||
|
|
||||||
|
def test_creates_fixtures_in_namespace(self):
|
||||||
|
"""Fixtures are created in the namespace."""
|
||||||
|
assert "fixture_roles" in globals()
|
||||||
|
assert "fixture_users" in globals()
|
||||||
|
assert "fixture_extra_users" in globals()
|
||||||
|
|
||||||
|
def test_fixtures_are_callable(self):
|
||||||
|
"""Created fixtures are callable."""
|
||||||
|
assert callable(globals()["fixture_roles"])
|
||||||
|
assert callable(globals()["fixture_users"])
|
||||||
|
|
||||||
|
|
||||||
|
class TestGeneratedFixtures:
|
||||||
|
"""Tests for the generated pytest fixtures."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_fixture_loads_data(
|
||||||
|
self, db_session: AsyncSession, fixture_roles: list[Role]
|
||||||
|
):
|
||||||
|
"""Fixture loads data into database and returns it."""
|
||||||
|
assert len(fixture_roles) == 2
|
||||||
|
assert fixture_roles[0].name == "plugin_admin"
|
||||||
|
assert fixture_roles[1].name == "plugin_user"
|
||||||
|
|
||||||
|
# Verify data is in database
|
||||||
|
count = await RoleCrud.count(db_session)
|
||||||
|
assert count == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_fixture_with_dependency(
|
||||||
|
self, db_session: AsyncSession, fixture_users: list[User]
|
||||||
|
):
|
||||||
|
"""Fixture with dependency loads parent fixture first."""
|
||||||
|
# fixture_users depends on fixture_roles
|
||||||
|
# Both should be loaded
|
||||||
|
assert len(fixture_users) == 2
|
||||||
|
|
||||||
|
# Roles should also be in database
|
||||||
|
roles_count = await RoleCrud.count(db_session)
|
||||||
|
assert roles_count == 2
|
||||||
|
|
||||||
|
# Users should be in database
|
||||||
|
users_count = await UserCrud.count(db_session)
|
||||||
|
assert users_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_fixture_returns_models(
|
||||||
|
self, db_session: AsyncSession, fixture_users: list[User]
|
||||||
|
):
|
||||||
|
"""Fixture returns actual model instances."""
|
||||||
|
user = fixture_users[0]
|
||||||
|
assert isinstance(user, User)
|
||||||
|
assert user.id == USER_ADMIN_ID
|
||||||
|
assert user.username == "plugin_admin"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_fixture_relationships_work(
|
||||||
|
self, db_session: AsyncSession, fixture_users: list[User]
|
||||||
|
):
|
||||||
|
"""Loaded fixtures have working relationships."""
|
||||||
|
# Load user with role relationship
|
||||||
|
user = await UserCrud.get(
|
||||||
|
db_session,
|
||||||
|
[User.id == USER_ADMIN_ID],
|
||||||
|
load_options=[selectinload(User.role)],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert user.role is not None
|
||||||
|
assert user.role.name == "plugin_admin"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_chained_dependencies(
|
||||||
|
self, db_session: AsyncSession, fixture_extra_users: list[User]
|
||||||
|
):
|
||||||
|
"""Chained dependencies are resolved correctly."""
|
||||||
|
# fixture_extra_users -> fixture_users -> fixture_roles
|
||||||
|
assert len(fixture_extra_users) == 1
|
||||||
|
|
||||||
|
# All fixtures should be loaded
|
||||||
|
roles_count = await RoleCrud.count(db_session)
|
||||||
|
users_count = await UserCrud.count(db_session)
|
||||||
|
|
||||||
|
assert roles_count == 2
|
||||||
|
assert users_count == 3 # 2 from users + 1 from extra_users
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_can_query_loaded_data(
|
||||||
|
self, db_session: AsyncSession, fixture_users: list[User]
|
||||||
|
):
|
||||||
|
"""Can query the loaded fixture data."""
|
||||||
|
# Get all users loaded by fixture
|
||||||
|
users = await UserCrud.get_multi(
|
||||||
|
db_session,
|
||||||
|
order_by=User.username,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(users) == 2
|
||||||
|
assert users[0].username == "plugin_admin"
|
||||||
|
assert users[1].username == "plugin_user"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_multiple_fixtures_in_same_test(
|
||||||
|
self,
|
||||||
|
db_session: AsyncSession,
|
||||||
|
fixture_roles: list[Role],
|
||||||
|
fixture_users: list[User],
|
||||||
|
):
|
||||||
|
"""Multiple fixtures can be used in the same test."""
|
||||||
|
assert len(fixture_roles) == 2
|
||||||
|
assert len(fixture_users) == 2
|
||||||
|
|
||||||
|
# Both should be in database
|
||||||
|
roles = await RoleCrud.get_multi(db_session)
|
||||||
|
users = await UserCrud.get_multi(db_session)
|
||||||
|
|
||||||
|
assert len(roles) == 2
|
||||||
|
assert len(users) == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateAsyncClient:
|
||||||
|
"""Tests for create_async_client helper."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_creates_working_client(self):
|
||||||
|
"""Client can make requests to the app."""
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health():
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
async with create_async_client(app) as client:
|
||||||
|
assert isinstance(client, AsyncClient)
|
||||||
|
response = await client.get("/health")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"status": "ok"}
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_custom_base_url(self):
|
||||||
|
"""Client uses custom base URL."""
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/test")
|
||||||
|
async def test_endpoint():
|
||||||
|
return {"url": "test"}
|
||||||
|
|
||||||
|
async with create_async_client(app, base_url="http://custom") as client:
|
||||||
|
assert str(client.base_url) == "http://custom"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_client_closes_properly(self):
|
||||||
|
"""Client is properly closed after context exit."""
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
async with create_async_client(app) as client:
|
||||||
|
client_ref = client
|
||||||
|
|
||||||
|
assert client_ref.is_closed
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_dependency_overrides_applied_and_cleaned(self):
|
||||||
|
"""Dependency overrides are applied during the context and removed after."""
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
async def original_dep() -> str:
|
||||||
|
return "original"
|
||||||
|
|
||||||
|
async def override_dep() -> str:
|
||||||
|
return "overridden"
|
||||||
|
|
||||||
|
@app.get("/dep")
|
||||||
|
async def dep_endpoint(value: str = Depends(original_dep)):
|
||||||
|
return {"value": value}
|
||||||
|
|
||||||
|
async with create_async_client(
|
||||||
|
app, dependency_overrides={original_dep: override_dep}
|
||||||
|
) as client:
|
||||||
|
response = await client.get("/dep")
|
||||||
|
assert response.json() == {"value": "overridden"}
|
||||||
|
|
||||||
|
# Overrides should be cleaned up
|
||||||
|
assert original_dep not in app.dependency_overrides
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateDbSession:
|
||||||
|
"""Tests for create_db_session helper."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_creates_working_session(self):
|
||||||
|
"""Session can perform database operations."""
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
async with create_db_session(DATABASE_URL, Base) as session:
|
||||||
|
assert isinstance(session, AsyncSession)
|
||||||
|
|
||||||
|
role = Role(id=role_id, name="test_helper_role")
|
||||||
|
session.add(role)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
result = await session.execute(select(Role).where(Role.id == role_id))
|
||||||
|
fetched = result.scalar_one()
|
||||||
|
assert fetched.name == "test_helper_role"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_tables_created_before_session(self):
|
||||||
|
"""Tables exist when session is yielded."""
|
||||||
|
async with create_db_session(DATABASE_URL, Base) as session:
|
||||||
|
# Should not raise - tables exist
|
||||||
|
result = await session.execute(select(Role))
|
||||||
|
assert result.all() == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_tables_dropped_after_session(self):
|
||||||
|
"""Tables are dropped after session closes when drop_tables=True."""
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||||
|
role = Role(id=role_id, name="will_be_dropped")
|
||||||
|
session.add(role)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Verify tables were dropped by creating new session
|
||||||
|
async with create_db_session(DATABASE_URL, Base) as session:
|
||||||
|
result = await session.execute(select(Role))
|
||||||
|
assert result.all() == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_tables_preserved_when_drop_disabled(self):
|
||||||
|
"""Tables are preserved when drop_tables=False."""
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
|
||||||
|
role = Role(id=role_id, name="preserved_role")
|
||||||
|
session.add(role)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Create another session without dropping
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
|
||||||
|
result = await session.execute(select(Role).where(Role.id == role_id))
|
||||||
|
fetched = result.scalar_one_or_none()
|
||||||
|
assert fetched is not None
|
||||||
|
assert fetched.name == "preserved_role"
|
||||||
|
|
||||||
|
# Cleanup: drop tables manually
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_cleanup_truncates_tables(self):
|
||||||
|
"""Tables are truncated after session closes when cleanup=True."""
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
async with create_db_session(
|
||||||
|
DATABASE_URL, Base, cleanup=True, drop_tables=False
|
||||||
|
) as session:
|
||||||
|
role = Role(id=role_id, name="will_be_cleaned")
|
||||||
|
session.add(role)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Data should have been truncated, but tables still exist
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||||
|
result = await session.execute(select(Role))
|
||||||
|
assert result.all() == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetXdistWorker:
|
||||||
|
"""Tests for _get_xdist_worker helper."""
|
||||||
|
|
||||||
|
def test_returns_default_test_db_without_env_var(
|
||||||
|
self, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
|
"""Returns default_test_db when PYTEST_XDIST_WORKER is not set."""
|
||||||
|
monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
|
||||||
|
assert _get_xdist_worker("my_default") == "my_default"
|
||||||
|
|
||||||
|
def test_returns_worker_name(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""Returns the worker name from the environment variable."""
|
||||||
|
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw0")
|
||||||
|
assert _get_xdist_worker("ignored") == "gw0"
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkerDatabaseUrl:
|
||||||
|
"""Tests for worker_database_url helper."""
|
||||||
|
|
||||||
|
def test_appends_default_test_db_without_xdist(
|
||||||
|
self, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
|
"""default_test_db is appended when not running under xdist."""
|
||||||
|
monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
|
||||||
|
url = "postgresql+asyncpg://user:pass@localhost:5432/mydb"
|
||||||
|
result = worker_database_url(url, default_test_db="fallback")
|
||||||
|
assert make_url(result).database == "mydb_fallback"
|
||||||
|
|
||||||
|
def test_appends_worker_id_to_database_name(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""Worker name is appended to the database name."""
|
||||||
|
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw0")
|
||||||
|
url = "postgresql+asyncpg://user:pass@localhost:5432/db"
|
||||||
|
result = worker_database_url(url, default_test_db="unused")
|
||||||
|
assert make_url(result).database == "db_gw0"
|
||||||
|
|
||||||
|
def test_preserves_url_components(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""Host, port, username, password, and driver are preserved."""
|
||||||
|
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw2")
|
||||||
|
url = "postgresql+asyncpg://myuser:secret@dbhost:6543/testdb"
|
||||||
|
result = make_url(worker_database_url(url, default_test_db="unused"))
|
||||||
|
|
||||||
|
assert result.drivername == "postgresql+asyncpg"
|
||||||
|
assert result.username == "myuser"
|
||||||
|
assert result.password == "secret"
|
||||||
|
assert result.host == "dbhost"
|
||||||
|
assert result.port == 6543
|
||||||
|
assert result.database == "testdb_gw2"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateWorkerDatabase:
|
||||||
|
"""Tests for create_worker_database context manager."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_creates_default_db_without_xdist(
|
||||||
|
self, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
|
"""Without xdist, creates a database suffixed with default_test_db."""
|
||||||
|
monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
|
||||||
|
default_test_db = "no_xdist_default"
|
||||||
|
expected_db = make_url(
|
||||||
|
worker_database_url(DATABASE_URL, default_test_db=default_test_db)
|
||||||
|
).database
|
||||||
|
|
||||||
|
async with create_worker_database(
|
||||||
|
DATABASE_URL, default_test_db=default_test_db
|
||||||
|
) as url:
|
||||||
|
assert make_url(url).database == expected_db
|
||||||
|
|
||||||
|
# Verify the database exists while inside the context
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||||
|
{"name": expected_db},
|
||||||
|
)
|
||||||
|
assert result.scalar() == 1
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
# After context exit the database should be dropped
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||||
|
{"name": expected_db},
|
||||||
|
)
|
||||||
|
assert result.scalar() is None
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_creates_and_drops_worker_database(
|
||||||
|
self, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
|
"""Worker database exists inside the context and is dropped after."""
|
||||||
|
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw_test_create")
|
||||||
|
expected_db = make_url(
|
||||||
|
worker_database_url(DATABASE_URL, default_test_db="unused")
|
||||||
|
).database
|
||||||
|
|
||||||
|
async with create_worker_database(DATABASE_URL) as url:
|
||||||
|
assert make_url(url).database == expected_db
|
||||||
|
|
||||||
|
# Verify the database exists while inside the context
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||||
|
{"name": expected_db},
|
||||||
|
)
|
||||||
|
assert result.scalar() == 1
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
# After context exit the database should be dropped
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||||
|
{"name": expected_db},
|
||||||
|
)
|
||||||
|
assert result.scalar() is None
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_cleans_up_stale_database(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""A pre-existing worker database is dropped and recreated."""
|
||||||
|
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw_test_stale")
|
||||||
|
expected_db = make_url(
|
||||||
|
worker_database_url(DATABASE_URL, default_test_db="unused")
|
||||||
|
).database
|
||||||
|
|
||||||
|
# Pre-create the database to simulate a stale leftover
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
await conn.execute(text(f"DROP DATABASE IF EXISTS {expected_db}"))
|
||||||
|
await conn.execute(text(f"CREATE DATABASE {expected_db}"))
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
# Should succeed despite the database already existing
|
||||||
|
async with create_worker_database(DATABASE_URL) as url:
|
||||||
|
assert make_url(url).database == expected_db
|
||||||
|
|
||||||
|
# Verify cleanup after context exit
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||||
|
{"name": expected_db},
|
||||||
|
)
|
||||||
|
assert result.scalar() is None
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
class TestCleanupTables:
|
||||||
|
"""Tests for cleanup_tables helper."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_truncates_all_tables(self):
|
||||||
|
"""All table rows are removed after cleanup_tables."""
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||||
|
role = Role(id=uuid.uuid4(), name="cleanup_role")
|
||||||
|
session.add(role)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
username="cleanup_user",
|
||||||
|
email="cleanup@test.com",
|
||||||
|
role_id=role.id,
|
||||||
|
)
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Verify rows exist
|
||||||
|
roles_count = await RoleCrud.count(session)
|
||||||
|
users_count = await UserCrud.count(session)
|
||||||
|
assert roles_count == 1
|
||||||
|
assert users_count == 1
|
||||||
|
|
||||||
|
await cleanup_tables(session, Base)
|
||||||
|
|
||||||
|
# Verify tables are empty
|
||||||
|
roles_count = await RoleCrud.count(session)
|
||||||
|
users_count = await UserCrud.count(session)
|
||||||
|
assert roles_count == 0
|
||||||
|
assert users_count == 0
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_noop_for_empty_metadata(self):
|
||||||
|
"""cleanup_tables does not raise when metadata has no tables."""
|
||||||
|
|
||||||
|
class EmptyBase(DeclarativeBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||||
|
# Should not raise
|
||||||
|
await cleanup_tables(session, EmptyBase)
|
||||||
@@ -1,160 +0,0 @@
|
|||||||
"""Tests for fastapi_toolsets.pytest_plugin module."""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.orm import selectinload
|
|
||||||
|
|
||||||
from fastapi_toolsets.fixtures import Context, FixtureRegistry, register_fixtures
|
|
||||||
|
|
||||||
from .conftest import Role, RoleCrud, User, UserCrud
|
|
||||||
|
|
||||||
test_registry = FixtureRegistry()
|
|
||||||
|
|
||||||
|
|
||||||
@test_registry.register(contexts=[Context.BASE])
|
|
||||||
def roles() -> list[Role]:
|
|
||||||
return [
|
|
||||||
Role(id=1000, name="plugin_admin"),
|
|
||||||
Role(id=1001, name="plugin_user"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@test_registry.register(depends_on=["roles"], contexts=[Context.BASE])
|
|
||||||
def users() -> list[User]:
|
|
||||||
return [
|
|
||||||
User(id=1000, username="plugin_admin", email="padmin@test.com", role_id=1000),
|
|
||||||
User(id=1001, username="plugin_user", email="puser@test.com", role_id=1001),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@test_registry.register(depends_on=["users"], contexts=[Context.TESTING])
|
|
||||||
def extra_users() -> list[User]:
|
|
||||||
return [
|
|
||||||
User(id=1002, username="plugin_extra", email="pextra@test.com", role_id=1001),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
register_fixtures(test_registry, globals())
|
|
||||||
|
|
||||||
|
|
||||||
class TestRegisterFixtures:
|
|
||||||
"""Tests for register_fixtures function."""
|
|
||||||
|
|
||||||
def test_creates_fixtures_in_namespace(self):
|
|
||||||
"""Fixtures are created in the namespace."""
|
|
||||||
assert "fixture_roles" in globals()
|
|
||||||
assert "fixture_users" in globals()
|
|
||||||
assert "fixture_extra_users" in globals()
|
|
||||||
|
|
||||||
def test_fixtures_are_callable(self):
|
|
||||||
"""Created fixtures are callable."""
|
|
||||||
assert callable(globals()["fixture_roles"])
|
|
||||||
assert callable(globals()["fixture_users"])
|
|
||||||
|
|
||||||
|
|
||||||
class TestGeneratedFixtures:
|
|
||||||
"""Tests for the generated pytest fixtures."""
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_fixture_loads_data(
|
|
||||||
self, db_session: AsyncSession, fixture_roles: list[Role]
|
|
||||||
):
|
|
||||||
"""Fixture loads data into database and returns it."""
|
|
||||||
assert len(fixture_roles) == 2
|
|
||||||
assert fixture_roles[0].name == "plugin_admin"
|
|
||||||
assert fixture_roles[1].name == "plugin_user"
|
|
||||||
|
|
||||||
# Verify data is in database
|
|
||||||
count = await RoleCrud.count(db_session, [Role.id >= 1000])
|
|
||||||
assert count == 2
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_fixture_with_dependency(
|
|
||||||
self, db_session: AsyncSession, fixture_users: list[User]
|
|
||||||
):
|
|
||||||
"""Fixture with dependency loads parent fixture first."""
|
|
||||||
# fixture_users depends on fixture_roles
|
|
||||||
# Both should be loaded
|
|
||||||
assert len(fixture_users) == 2
|
|
||||||
|
|
||||||
# Roles should also be in database
|
|
||||||
roles_count = await RoleCrud.count(db_session, [Role.id >= 1000])
|
|
||||||
assert roles_count == 2
|
|
||||||
|
|
||||||
# Users should be in database
|
|
||||||
users_count = await UserCrud.count(db_session, [User.id >= 1000])
|
|
||||||
assert users_count == 2
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_fixture_returns_models(
|
|
||||||
self, db_session: AsyncSession, fixture_users: list[User]
|
|
||||||
):
|
|
||||||
"""Fixture returns actual model instances."""
|
|
||||||
user = fixture_users[0]
|
|
||||||
assert isinstance(user, User)
|
|
||||||
assert user.id == 1000
|
|
||||||
assert user.username == "plugin_admin"
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_fixture_relationships_work(
|
|
||||||
self, db_session: AsyncSession, fixture_users: list[User]
|
|
||||||
):
|
|
||||||
"""Loaded fixtures have working relationships."""
|
|
||||||
# Load user with role relationship
|
|
||||||
user = await UserCrud.get(
|
|
||||||
db_session,
|
|
||||||
[User.id == 1000],
|
|
||||||
load_options=[selectinload(User.role)],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert user.role is not None
|
|
||||||
assert user.role.name == "plugin_admin"
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_chained_dependencies(
|
|
||||||
self, db_session: AsyncSession, fixture_extra_users: list[User]
|
|
||||||
):
|
|
||||||
"""Chained dependencies are resolved correctly."""
|
|
||||||
# fixture_extra_users -> fixture_users -> fixture_roles
|
|
||||||
assert len(fixture_extra_users) == 1
|
|
||||||
|
|
||||||
# All fixtures should be loaded
|
|
||||||
roles_count = await RoleCrud.count(db_session, [Role.id >= 1000])
|
|
||||||
users_count = await UserCrud.count(db_session, [User.id >= 1000])
|
|
||||||
|
|
||||||
assert roles_count == 2
|
|
||||||
assert users_count == 3 # 2 from users + 1 from extra_users
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_can_query_loaded_data(
|
|
||||||
self, db_session: AsyncSession, fixture_users: list[User]
|
|
||||||
):
|
|
||||||
"""Can query the loaded fixture data."""
|
|
||||||
# Get all users loaded by fixture
|
|
||||||
users = await UserCrud.get_multi(
|
|
||||||
db_session,
|
|
||||||
filters=[User.id >= 1000],
|
|
||||||
order_by=User.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(users) == 2
|
|
||||||
assert users[0].username == "plugin_admin"
|
|
||||||
assert users[1].username == "plugin_user"
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_multiple_fixtures_in_same_test(
|
|
||||||
self,
|
|
||||||
db_session: AsyncSession,
|
|
||||||
fixture_roles: list[Role],
|
|
||||||
fixture_users: list[User],
|
|
||||||
):
|
|
||||||
"""Multiple fixtures can be used in the same test."""
|
|
||||||
assert len(fixture_roles) == 2
|
|
||||||
assert len(fixture_users) == 2
|
|
||||||
|
|
||||||
# Both should be in database
|
|
||||||
roles = await RoleCrud.get_multi(db_session, filters=[Role.id >= 1000])
|
|
||||||
users = await UserCrud.get_multi(db_session, filters=[User.id >= 1000])
|
|
||||||
|
|
||||||
assert len(roles) == 2
|
|
||||||
assert len(users) == 2
|
|
||||||
@@ -5,7 +5,9 @@ from pydantic import ValidationError
|
|||||||
|
|
||||||
from fastapi_toolsets.schemas import (
|
from fastapi_toolsets.schemas import (
|
||||||
ApiError,
|
ApiError,
|
||||||
|
CursorPagination,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
|
OffsetPagination,
|
||||||
PaginatedResponse,
|
PaginatedResponse,
|
||||||
Pagination,
|
Pagination,
|
||||||
Response,
|
Response,
|
||||||
@@ -46,6 +48,31 @@ class TestApiError:
|
|||||||
assert error.desc == "The resource was not found."
|
assert error.desc == "The resource was not found."
|
||||||
assert error.err_code == "RES-404"
|
assert error.err_code == "RES-404"
|
||||||
|
|
||||||
|
def test_data_defaults_to_none(self):
|
||||||
|
"""ApiError data field defaults to None."""
|
||||||
|
error = ApiError(
|
||||||
|
code=404,
|
||||||
|
msg="Not Found",
|
||||||
|
desc="The resource was not found.",
|
||||||
|
err_code="RES-404",
|
||||||
|
)
|
||||||
|
assert error.data is None
|
||||||
|
|
||||||
|
def test_create_with_data(self):
|
||||||
|
"""ApiError can be created with a data payload."""
|
||||||
|
error = ApiError(
|
||||||
|
code=422,
|
||||||
|
msg="Validation Error",
|
||||||
|
desc="2 validation error(s) detected",
|
||||||
|
err_code="VAL-422",
|
||||||
|
data={
|
||||||
|
"errors": [{"field": "name", "message": "required", "type": "missing"}]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert error.data == {
|
||||||
|
"errors": [{"field": "name", "message": "required", "type": "missing"}]
|
||||||
|
}
|
||||||
|
|
||||||
def test_requires_all_fields(self):
|
def test_requires_all_fields(self):
|
||||||
"""ApiError requires all fields."""
|
"""ApiError requires all fields."""
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
@@ -129,12 +156,12 @@ class TestErrorResponse:
|
|||||||
assert data["description"] == "Details"
|
assert data["description"] == "Details"
|
||||||
|
|
||||||
|
|
||||||
class TestPagination:
|
class TestOffsetPagination:
|
||||||
"""Tests for Pagination schema."""
|
"""Tests for OffsetPagination schema (canonical name for offset-based pagination)."""
|
||||||
|
|
||||||
def test_create_pagination(self):
|
def test_create_pagination(self):
|
||||||
"""Create Pagination with all fields."""
|
"""Create OffsetPagination with all fields."""
|
||||||
pagination = Pagination(
|
pagination = OffsetPagination(
|
||||||
total_count=100,
|
total_count=100,
|
||||||
items_per_page=10,
|
items_per_page=10,
|
||||||
page=1,
|
page=1,
|
||||||
@@ -148,7 +175,7 @@ class TestPagination:
|
|||||||
|
|
||||||
def test_last_page_has_more_false(self):
|
def test_last_page_has_more_false(self):
|
||||||
"""Last page has has_more=False."""
|
"""Last page has has_more=False."""
|
||||||
pagination = Pagination(
|
pagination = OffsetPagination(
|
||||||
total_count=25,
|
total_count=25,
|
||||||
items_per_page=10,
|
items_per_page=10,
|
||||||
page=3,
|
page=3,
|
||||||
@@ -158,8 +185,8 @@ class TestPagination:
|
|||||||
assert pagination.has_more is False
|
assert pagination.has_more is False
|
||||||
|
|
||||||
def test_serialization(self):
|
def test_serialization(self):
|
||||||
"""Pagination serializes correctly."""
|
"""OffsetPagination serializes correctly."""
|
||||||
pagination = Pagination(
|
pagination = OffsetPagination(
|
||||||
total_count=50,
|
total_count=50,
|
||||||
items_per_page=20,
|
items_per_page=20,
|
||||||
page=2,
|
page=2,
|
||||||
@@ -172,6 +199,77 @@ class TestPagination:
|
|||||||
assert data["page"] == 2
|
assert data["page"] == 2
|
||||||
assert data["has_more"] is True
|
assert data["has_more"] is True
|
||||||
|
|
||||||
|
def test_pagination_alias_is_offset_pagination(self):
|
||||||
|
"""Pagination is a backward-compatible alias for OffsetPagination."""
|
||||||
|
assert Pagination is OffsetPagination
|
||||||
|
|
||||||
|
def test_pagination_alias_constructs_offset_pagination(self):
|
||||||
|
"""Code using Pagination(...) still works unchanged."""
|
||||||
|
pagination = Pagination(
|
||||||
|
total_count=10,
|
||||||
|
items_per_page=5,
|
||||||
|
page=2,
|
||||||
|
has_more=False,
|
||||||
|
)
|
||||||
|
assert isinstance(pagination, OffsetPagination)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCursorPagination:
|
||||||
|
"""Tests for CursorPagination schema."""
|
||||||
|
|
||||||
|
def test_create_with_next_cursor(self):
|
||||||
|
"""CursorPagination with a next cursor indicates more pages."""
|
||||||
|
pagination = CursorPagination(
|
||||||
|
next_cursor="eyJ2YWx1ZSI6ICIxMjMifQ==",
|
||||||
|
items_per_page=20,
|
||||||
|
has_more=True,
|
||||||
|
)
|
||||||
|
assert pagination.next_cursor == "eyJ2YWx1ZSI6ICIxMjMifQ=="
|
||||||
|
assert pagination.prev_cursor is None
|
||||||
|
assert pagination.items_per_page == 20
|
||||||
|
assert pagination.has_more is True
|
||||||
|
|
||||||
|
def test_create_last_page(self):
|
||||||
|
"""CursorPagination for the last page has next_cursor=None and has_more=False."""
|
||||||
|
pagination = CursorPagination(
|
||||||
|
next_cursor=None,
|
||||||
|
items_per_page=20,
|
||||||
|
has_more=False,
|
||||||
|
)
|
||||||
|
assert pagination.next_cursor is None
|
||||||
|
assert pagination.has_more is False
|
||||||
|
|
||||||
|
def test_prev_cursor_defaults_to_none(self):
|
||||||
|
"""prev_cursor defaults to None."""
|
||||||
|
pagination = CursorPagination(
|
||||||
|
next_cursor=None, items_per_page=10, has_more=False
|
||||||
|
)
|
||||||
|
assert pagination.prev_cursor is None
|
||||||
|
|
||||||
|
def test_prev_cursor_can_be_set(self):
|
||||||
|
"""prev_cursor can be explicitly set."""
|
||||||
|
pagination = CursorPagination(
|
||||||
|
next_cursor="next123",
|
||||||
|
prev_cursor="prev456",
|
||||||
|
items_per_page=10,
|
||||||
|
has_more=True,
|
||||||
|
)
|
||||||
|
assert pagination.prev_cursor == "prev456"
|
||||||
|
|
||||||
|
def test_serialization(self):
|
||||||
|
"""CursorPagination serializes correctly."""
|
||||||
|
pagination = CursorPagination(
|
||||||
|
next_cursor="abc123",
|
||||||
|
prev_cursor="xyz789",
|
||||||
|
items_per_page=20,
|
||||||
|
has_more=True,
|
||||||
|
)
|
||||||
|
data = pagination.model_dump()
|
||||||
|
assert data["next_cursor"] == "abc123"
|
||||||
|
assert data["prev_cursor"] == "xyz789"
|
||||||
|
assert data["items_per_page"] == 20
|
||||||
|
assert data["has_more"] is True
|
||||||
|
|
||||||
|
|
||||||
class TestPaginatedResponse:
|
class TestPaginatedResponse:
|
||||||
"""Tests for PaginatedResponse schema."""
|
"""Tests for PaginatedResponse schema."""
|
||||||
@@ -189,6 +287,7 @@ class TestPaginatedResponse:
|
|||||||
pagination=pagination,
|
pagination=pagination,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert isinstance(response.pagination, OffsetPagination)
|
||||||
assert len(response.data) == 2
|
assert len(response.data) == 2
|
||||||
assert response.pagination.total_count == 30
|
assert response.pagination.total_count == 30
|
||||||
assert response.status == ResponseStatus.SUCCESS
|
assert response.status == ResponseStatus.SUCCESS
|
||||||
@@ -222,6 +321,7 @@ class TestPaginatedResponse:
|
|||||||
pagination=pagination,
|
pagination=pagination,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert isinstance(response.pagination, OffsetPagination)
|
||||||
assert response.data == []
|
assert response.data == []
|
||||||
assert response.pagination.total_count == 0
|
assert response.pagination.total_count == 0
|
||||||
|
|
||||||
@@ -265,6 +365,36 @@ class TestPaginatedResponse:
|
|||||||
assert data["data"] == ["item1", "item2"]
|
assert data["data"] == ["item1", "item2"]
|
||||||
assert data["pagination"]["page"] == 5
|
assert data["pagination"]["page"] == 5
|
||||||
|
|
||||||
|
def test_pagination_field_accepts_offset_pagination(self):
|
||||||
|
"""PaginatedResponse.pagination accepts OffsetPagination."""
|
||||||
|
response = PaginatedResponse(
|
||||||
|
data=[1, 2],
|
||||||
|
pagination=OffsetPagination(
|
||||||
|
total_count=2, items_per_page=10, page=1, has_more=False
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert isinstance(response.pagination, OffsetPagination)
|
||||||
|
|
||||||
|
def test_pagination_field_accepts_cursor_pagination(self):
|
||||||
|
"""PaginatedResponse.pagination accepts CursorPagination."""
|
||||||
|
response = PaginatedResponse(
|
||||||
|
data=[1, 2],
|
||||||
|
pagination=CursorPagination(
|
||||||
|
next_cursor=None, items_per_page=10, has_more=False
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert isinstance(response.pagination, CursorPagination)
|
||||||
|
|
||||||
|
def test_pagination_alias_accepted(self):
|
||||||
|
"""Constructing PaginatedResponse with Pagination (alias) still works."""
|
||||||
|
response = PaginatedResponse(
|
||||||
|
data=[],
|
||||||
|
pagination=Pagination(
|
||||||
|
total_count=0, items_per_page=10, page=1, has_more=False
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert isinstance(response.pagination, OffsetPagination)
|
||||||
|
|
||||||
|
|
||||||
class TestFromAttributes:
|
class TestFromAttributes:
|
||||||
"""Tests for from_attributes config (ORM mode)."""
|
"""Tests for from_attributes config (ORM mode)."""
|
||||||
|
|||||||
318
zensical.toml
Normal file
318
zensical.toml
Normal file
@@ -0,0 +1,318 @@
|
|||||||
|
# ============================================================================
|
||||||
|
#
|
||||||
|
# The configuration produced by default is meant to highlight the features
|
||||||
|
# that Zensical provides and to serve as a starting point for your own
|
||||||
|
# projects.
|
||||||
|
#
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
[project]
|
||||||
|
|
||||||
|
# The site_name is shown in the page header and the browser window title
|
||||||
|
#
|
||||||
|
# Read more: https://zensical.org/docs/setup/basics/#site_name
|
||||||
|
site_name = "FastAPI Toolsets"
|
||||||
|
|
||||||
|
# The site_description is included in the HTML head and should contain a
|
||||||
|
# meaningful description of the site content for use by search engines.
|
||||||
|
#
|
||||||
|
# Read more: https://zensical.org/docs/setup/basics/#site_description
|
||||||
|
site_description = "Production-ready utilities for FastAPI applications."
|
||||||
|
|
||||||
|
# The site_author attribute. This is used in the HTML head element.
|
||||||
|
#
|
||||||
|
# Read more: https://zensical.org/docs/setup/basics/#site_author
|
||||||
|
site_author = "d3vyce"
|
||||||
|
|
||||||
|
# The site_url is the canonical URL for your site. When building online
|
||||||
|
# documentation you should set this.
|
||||||
|
# Read more: https://zensical.org/docs/setup/basics/#site_url
|
||||||
|
site_url = "https://fastapi-toolsets.d3vyce.fr"
|
||||||
|
|
||||||
|
# The copyright notice appears in the page footer and can contain an HTML
|
||||||
|
# fragment.
|
||||||
|
#
|
||||||
|
# Read more: https://zensical.org/docs/setup/basics/#copyright
|
||||||
|
copyright = """
|
||||||
|
Copyright © 2026 d3vyce
|
||||||
|
"""
|
||||||
|
|
||||||
|
repo_url = "https://github.com/d3vyce/fastapi-toolsets"
|
||||||
|
|
||||||
|
# Zensical supports both implicit navigation and explicitly defined navigation.
|
||||||
|
# If you decide not to define a navigation here then Zensical will simply
|
||||||
|
# derive the navigation structure from the directory structure of your
|
||||||
|
# "docs_dir". The definition below demonstrates how a navigation structure
|
||||||
|
# can be defined using TOML syntax.
|
||||||
|
#
|
||||||
|
# Read more: https://zensical.org/docs/setup/navigation/
|
||||||
|
# nav = [
|
||||||
|
# { "Get started" = "index.md" },
|
||||||
|
# { "Markdown in 5min" = "markdown.md" },
|
||||||
|
# ]
|
||||||
|
|
||||||
|
# With the "extra_css" option you can add your own CSS styling to customize
|
||||||
|
# your Zensical project according to your needs. You can add any number of
|
||||||
|
# CSS files.
|
||||||
|
#
|
||||||
|
# The path provided should be relative to the "docs_dir".
|
||||||
|
#
|
||||||
|
# Read more: https://zensical.org/docs/customization/#additional-css
|
||||||
|
#
|
||||||
|
#extra_css = ["stylesheets/extra.css"]
|
||||||
|
|
||||||
|
# With the `extra_javascript` option you can add your own JavaScript to your
|
||||||
|
# project to customize the behavior according to your needs.
|
||||||
|
#
|
||||||
|
# The path provided should be relative to the "docs_dir".
|
||||||
|
#
|
||||||
|
# Read more: https://zensical.org/docs/customization/#additional-javascript
|
||||||
|
#extra_javascript = ["javascripts/extra.js"]
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
# Section for configuring theme options
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
[project.theme]
|
||||||
|
|
||||||
|
# change this to "classic" to use the traditional Material for MkDocs look.
|
||||||
|
#variant = "classic"
|
||||||
|
|
||||||
|
# Zensical allows you to override specific blocks, partials, or whole
|
||||||
|
# templates as well as to define your own templates. To do this, uncomment
|
||||||
|
# the custom_dir setting below and set it to a directory in which you
|
||||||
|
# keep your template overrides.
|
||||||
|
#
|
||||||
|
# Read more:
|
||||||
|
# - https://zensical.org/docs/customization/#extending-the-theme
|
||||||
|
#
|
||||||
|
custom_dir = "docs/overrides"
|
||||||
|
|
||||||
|
# With the "favicon" option you can set your own image to use as the icon
|
||||||
|
# browsers will use in the browser title bar or tab bar. The path provided
|
||||||
|
# must be relative to the "docs_dir".
|
||||||
|
#
|
||||||
|
# Read more:
|
||||||
|
# - https://zensical.org/docs/setup/logo-and-icons/#favicon
|
||||||
|
# - https://developer.mozilla.org/en-US/docs/Glossary/Favicon
|
||||||
|
#
|
||||||
|
#favicon = "images/favicon.png"
|
||||||
|
|
||||||
|
# Zensical supports more than 60 different languages. This means that the
|
||||||
|
# labels and tooltips that Zensical's templates produce are translated.
|
||||||
|
# The "language" option allows you to set the language used. This language
|
||||||
|
# is also indicated in the HTML head element to help with accessibility
|
||||||
|
# and guide search engines and translation tools.
|
||||||
|
#
|
||||||
|
# The default language is "en" (English). It is possible to create
|
||||||
|
# sites with multiple languages and configure a language selector. See
|
||||||
|
# the documentation for details.
|
||||||
|
#
|
||||||
|
# Read more:
|
||||||
|
# - https://zensical.org/docs/setup/language/
|
||||||
|
#
|
||||||
|
language = "en"
|
||||||
|
|
||||||
|
# Zensical provides a number of feature toggles that change the behavior
|
||||||
|
# of the documentation site.
|
||||||
|
features = [
|
||||||
|
# Zensical includes an announcement bar. This feature allows users to
|
||||||
|
# dismiss it when they have read the announcement.
|
||||||
|
# https://zensical.org/docs/setup/header/#announcement-bar
|
||||||
|
"announce.dismiss",
|
||||||
|
|
||||||
|
# If you have a repository configured and turn on this feature, Zensical
|
||||||
|
# will generate an edit button for the page. This works for common
|
||||||
|
# repository hosting services.
|
||||||
|
# https://zensical.org/docs/setup/repository/#content-actions
|
||||||
|
#"content.action.edit",
|
||||||
|
|
||||||
|
# If you have a repository configured and turn on this feature, Zensical
|
||||||
|
# will generate a button that allows the user to view the Markdown
|
||||||
|
# code for the current page.
|
||||||
|
# https://zensical.org/docs/setup/repository/#content-actions
|
||||||
|
"content.action.view",
|
||||||
|
|
||||||
|
# Code annotations allow you to add an icon with a tooltip to your
|
||||||
|
# code blocks to provide explanations at crucial points.
|
||||||
|
# https://zensical.org/docs/authoring/code-blocks/#code-annotations
|
||||||
|
"content.code.annotate",
|
||||||
|
|
||||||
|
# This feature turns on a button in code blocks that allow users to
|
||||||
|
# copy the content to their clipboard without first selecting it.
|
||||||
|
# https://zensical.org/docs/authoring/code-blocks/#code-copy-button
|
||||||
|
"content.code.copy",
|
||||||
|
|
||||||
|
# Code blocks can include a button to allow for the selection of line
|
||||||
|
# ranges by the user.
|
||||||
|
# https://zensical.org/docs/authoring/code-blocks/#code-selection-button
|
||||||
|
"content.code.select",
|
||||||
|
|
||||||
|
# Zensical can render footnotes as inline tooltips, so the user can read
|
||||||
|
# the footnote without leaving the context of the document.
|
||||||
|
# https://zensical.org/docs/authoring/footnotes/#footnote-tooltips
|
||||||
|
"content.footnote.tooltips",
|
||||||
|
|
||||||
|
# If you have many content tabs that have the same titles (e.g., "Python",
|
||||||
|
# "JavaScript", "Cobol"), this feature causes all of them to switch to
|
||||||
|
# at the same time when the user chooses their language in one.
|
||||||
|
# https://zensical.org/docs/authoring/content-tabs/#linked-content-tabs
|
||||||
|
"content.tabs.link",
|
||||||
|
|
||||||
|
# With this feature enabled users can add tooltips to links that will be
|
||||||
|
# displayed when the mouse pointer hovers the link.
|
||||||
|
# https://zensical.org/docs/authoring/tooltips/#improved-tooltips
|
||||||
|
"content.tooltips",
|
||||||
|
|
||||||
|
# With this feature enabled, Zensical will automatically hide parts
|
||||||
|
# of the header when the user scrolls past a certain point.
|
||||||
|
# https://zensical.org/docs/setup/header/#automatic-hiding
|
||||||
|
# "header.autohide",
|
||||||
|
|
||||||
|
# Turn on this feature to expand all collapsible sections in the
|
||||||
|
# navigation sidebar by default.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#navigation-expansion
|
||||||
|
# "navigation.expand",
|
||||||
|
|
||||||
|
# This feature turns on navigation elements in the footer that allow the
|
||||||
|
# user to navigate to a next or previous page.
|
||||||
|
# https://zensical.org/docs/setup/footer/#navigation
|
||||||
|
"navigation.footer",
|
||||||
|
|
||||||
|
# When section index pages are enabled, documents can be directly attached
|
||||||
|
# to sections, which is particularly useful for providing overview pages.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#section-index-pages
|
||||||
|
"navigation.indexes",
|
||||||
|
|
||||||
|
# When instant navigation is enabled, clicks on all internal links will be
|
||||||
|
# intercepted and dispatched via XHR without fully reloading the page.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#instant-navigation
|
||||||
|
"navigation.instant",
|
||||||
|
|
||||||
|
# With instant prefetching, your site will start to fetch a page once the
|
||||||
|
# user hovers over a link. This will reduce the perceived loading time
|
||||||
|
# for the user.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#instant-prefetching
|
||||||
|
"navigation.instant.prefetch",
|
||||||
|
|
||||||
|
# In order to provide a better user experience on slow connections when
|
||||||
|
# using instant navigation, a progress indicator can be enabled.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#progress-indicator
|
||||||
|
#"navigation.instant.progress",
|
||||||
|
|
||||||
|
# When navigation paths are activated, a breadcrumb navigation is rendered
|
||||||
|
# above the title of each page
|
||||||
|
# https://zensical.org/docs/setup/navigation/#navigation-path
|
||||||
|
"navigation.path",
|
||||||
|
|
||||||
|
# When pruning is enabled, only the visible navigation items are included
|
||||||
|
# in the rendered HTML, reducing the size of the built site by 33% or more.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#navigation-pruning
|
||||||
|
#"navigation.prune",
|
||||||
|
|
||||||
|
# When sections are enabled, top-level sections are rendered as groups in
|
||||||
|
# the sidebar for viewports above 1220px, but remain as-is on mobile.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#navigation-sections
|
||||||
|
"navigation.sections",
|
||||||
|
|
||||||
|
# When tabs are enabled, top-level sections are rendered in a menu layer
|
||||||
|
# below the header for viewports above 1220px, but remain as-is on mobile.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#navigation-tabs
|
||||||
|
"navigation.tabs",
|
||||||
|
|
||||||
|
# When sticky tabs are enabled, navigation tabs will lock below the header
|
||||||
|
# and always remain visible when scrolling down.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#sticky-navigation-tabs
|
||||||
|
#"navigation.tabs.sticky",
|
||||||
|
|
||||||
|
# A back-to-top button can be shown when the user, after scrolling down,
|
||||||
|
# starts to scroll up again.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#back-to-top-button
|
||||||
|
"navigation.top",
|
||||||
|
|
||||||
|
# When anchor tracking is enabled, the URL in the address bar is
|
||||||
|
# automatically updated with the active anchor as highlighted in the table
|
||||||
|
# of contents.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#anchor-tracking
|
||||||
|
"navigation.tracking",
|
||||||
|
|
||||||
|
# When search highlighting is enabled and a user clicks on a search result,
|
||||||
|
# Zensical will highlight all occurrences after following the link.
|
||||||
|
# https://zensical.org/docs/setup/search/#search-highlighting
|
||||||
|
"search.highlight",
|
||||||
|
|
||||||
|
# When anchor following for the table of contents is enabled, the sidebar
|
||||||
|
# is automatically scrolled so that the active anchor is always visible.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#anchor-following
|
||||||
|
# "toc.follow",
|
||||||
|
|
||||||
|
# When navigation integration for the table of contents is enabled, it is
|
||||||
|
# always rendered as part of the navigation sidebar on the left.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#navigation-integration
|
||||||
|
#"toc.integrate",
|
||||||
|
]
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
# In the "palette" subsection you can configure options for the color scheme.
|
||||||
|
# You can configure different color # schemes, e.g., to turn on dark mode,
|
||||||
|
# that the user can switch between. Each color scheme can be further
|
||||||
|
# customized.
|
||||||
|
#
|
||||||
|
# Read more:
|
||||||
|
# - https://zensical.org/docs/setup/colors/
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
[[project.theme.palette]]
|
||||||
|
scheme = "default"
|
||||||
|
toggle.icon = "lucide/sun"
|
||||||
|
toggle.name = "Switch to dark mode"
|
||||||
|
|
||||||
|
[[project.theme.palette]]
|
||||||
|
scheme = "slate"
|
||||||
|
toggle.icon = "lucide/moon"
|
||||||
|
toggle.name = "Switch to light mode"
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
# In the "font" subsection you can configure the fonts used. By default, fonts
|
||||||
|
# are loaded from Google Fonts, giving you a wide range of choices from a set
|
||||||
|
# of suitably licensed fonts. There are options for a normal text font and for
|
||||||
|
# a monospaced font used in code blocks.
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
[project.theme.font]
|
||||||
|
text = "Inter"
|
||||||
|
code = "Jetbrains Mono"
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
# You can configure your own logo to be shown in the header using the "logo"
|
||||||
|
# option in the "icons" subsection. The logo can be a path to a file in your
|
||||||
|
# "docs_dir" or it can be a path to an icon.
|
||||||
|
#
|
||||||
|
# Likewise, you can customize the logo used for the repository section of the
|
||||||
|
# header. Zensical derives the default logo for this from the repository URL.
|
||||||
|
# See below...
|
||||||
|
#
|
||||||
|
# There are other icons you can customize. See the documentation for details.
|
||||||
|
#
|
||||||
|
# Read more:
|
||||||
|
# - https://zensical.org/docs/setup/logo-and-icons
|
||||||
|
# - https://zensical.org/docs/authoring/icons-emojis/#search
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
[project.theme.icon]
|
||||||
|
#logo = "lucide/smile"
|
||||||
|
repo = "fontawesome/brands/github"
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
# The "extra" section contains miscellaneous settings.
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
#[[project.extra.social]]
|
||||||
|
#icon = "fontawesome/brands/github"
|
||||||
|
#link = "https://github.com/user/repo"
|
||||||
|
|
||||||
|
|
||||||
|
[project.plugins.mkdocstrings.handlers.python]
|
||||||
|
inventories = ["https://docs.python.org/3/objects.inv"]
|
||||||
|
paths = ["src"]
|
||||||
|
|
||||||
|
[project.plugins.mkdocstrings.handlers.python.options]
|
||||||
|
docstring_style = "google"
|
||||||
|
inherited_members = true
|
||||||
|
show_source = false
|
||||||
|
show_root_heading = true
|
||||||
Reference in New Issue
Block a user